File size: 3,530 Bytes
318503e
 
8a5aaef
26009d1
8a5aaef
 
 
 
 
 
26009d1
318503e
8a5aaef
318503e
 
 
 
8a5aaef
 
 
 
 
 
 
 
 
 
26009d1
8a5aaef
318503e
 
 
 
 
 
 
 
 
26009d1
8a5aaef
 
 
318503e
 
26009d1
318503e
 
 
 
8a5aaef
 
318503e
 
 
 
 
 
8a5aaef
26009d1
8a5aaef
 
318503e
 
26009d1
 
318503e
8a5aaef
318503e
 
 
8a5aaef
318503e
 
 
8a5aaef
 
 
26009d1
8a5aaef
318503e
 
8a5aaef
318503e
26009d1
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os, torch, gradio as gr, spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from peft import PeftModel

# ---- IDs (can override from Space Secrets) ----
# Base model and LoRA adapter repo ids; env vars let a Space override them
# without a code change.
BASE_ID    = os.getenv("BASE_ID",    "Qwen/Qwen2.5-3B-Instruct")
ADAPTER_ID = os.getenv("ADAPTER_ID", "JDhruv14/Gita-FT-v2-Qwen2.5-3B")

# ---- Load tokenizer & base model ----
tokenizer = AutoTokenizer.from_pretrained(BASE_ID, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    BASE_ID,
    device_map="auto",
    # bf16 on GPU; "auto" lets transformers pick a dtype on CPU-only hosts.
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else "auto",
    trust_remote_code=True,
)
# Apply LoRA adapter
model = PeftModel.from_pretrained(model, ADAPTER_ID)
model.eval()

def _eos_ids(tok):
    ids = {tok.eos_token_id}
    im_end = tok.convert_tokens_to_ids("<|im_end|>")
    if im_end is not None:
        ids.add(im_end)
    return list(ids)

def _format_history(history, system_text):
    msgs = []
    if system_text:
        msgs.append({"role": "system", "content": system_text})
    for user, assistant in history:
        if user:
            msgs.append({"role": "user", "content": user})
        if assistant:
            msgs.append({"role": "assistant", "content": assistant})
    return msgs

@spaces.GPU(duration=120)   # keep for ZeroGPU; remove this decorator if using a normal GPU Space
def chat_fn(message, history, system_text, temperature, top_p, max_new_tokens, min_new_tokens):
    """Generate one assistant reply for the current user turn.

    Args:
        message: The new user message.
        history: Prior chat turns in a Gradio history format.
        system_text: System prompt prepended to the conversation.
        temperature / top_p: Sampling parameters.
        max_new_tokens / min_new_tokens: Generation length bounds.

    Returns:
        str: The decoded assistant reply (prompt tokens stripped).
    """
    conversation = _format_history(history, system_text)
    conversation.append({"role": "user", "content": message})
    prompt_text = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    encoded = tokenizer([prompt_text], return_tensors="pt").to(model.device)

    config = GenerationConfig(
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        max_new_tokens=int(max_new_tokens),
        min_new_tokens=int(min_new_tokens),
        repetition_penalty=1.02,
        no_repeat_ngram_size=3,
        eos_token_id=_eos_ids(tokenizer),
        pad_token_id=tokenizer.eos_token_id,
    )
    with torch.no_grad():
        generated = model.generate(**encoded, generation_config=config)

    # The model echoes the prompt; decode only the newly generated suffix.
    prompt_len = encoded["input_ids"].shape[1]
    completion = generated[:, prompt_len:]
    return tokenizer.batch_decode(completion, skip_special_tokens=True)[0].strip()

with gr.Blocks() as demo:
    gr.Markdown(
        "<h1 style='text-align:center'>Gita Assistant (Qwen2.5-3B + LoRA)</h1>"
        "<p style='text-align:center'>Ask in English / हिंदी / ગુજરાતી. The assistant cites verses when relevant.</p>"
    )
    system_box = gr.Textbox(
        value="Reply in the user’s language with 2–3 concrete points (200–400 words); cite Gita verses when relevant.",
        label="System prompt",
    )
    temperature = gr.Slider(0.1, 1.2, value=0.7, step=0.05, label="temperature")
    top_p       = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
    max_new     = gr.Slider(64, 1024, value=512, step=16, label="max_new_tokens")
    min_new     = gr.Slider(0, 512, value=160, step=8, label="min_new_tokens")

    # Pass chat_fn directly: ChatInterface invokes fn(message, history,
    # *additional_inputs), so the live component values reach the model.
    # The previous lambda accepted only (m, h) — a TypeError once the extra
    # args were passed — and read each component's static initial .value,
    # so user-adjusted sliders would never have taken effect.
    gr.ChatInterface(
        fn=chat_fn,
        additional_inputs=[system_box, temperature, top_p, max_new, min_new],
        retry_btn="Regenerate", undo_btn="Undo Last", clear_btn="Clear", queue=True
    )

if __name__ == "__main__":
    demo.launch()