File size: 3,394 Bytes
33dd5ba
 
8a5aaef
33dd5ba
26009d1
33dd5ba
 
 
 
318503e
33dd5ba
 
318503e
26009d1
33dd5ba
 
 
 
 
 
 
 
 
 
26009d1
33dd5ba
 
 
 
 
 
26009d1
33dd5ba
 
 
 
 
26009d1
33dd5ba
 
 
 
 
 
 
 
 
 
 
 
 
26009d1
33dd5ba
 
 
 
0d61b36
33dd5ba
e8c693f
33dd5ba
 
318503e
33dd5ba
 
 
318503e
33dd5ba
 
 
 
e8c693f
33dd5ba
 
 
 
 
 
 
 
 
26009d1
 
b81b02b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os, torch, gradio as gr, spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

# Hub repo (or local path) of the checkpoint; override with the MODEL_ID env var.
MODEL_ID = os.environ.get("MODEL_ID", "JDhruv14/merged_model")

# Tokenizer and weights are loaded once, at import time, and shared by every
# request. device_map="auto" leaves placement to accelerate; bf16 is requested
# only when CUDA reports as available, otherwise the checkpoint's own dtype.
# NOTE(review): on ZeroGPU the `spaces` package manages GPU attachment — confirm
# where the weights actually live before the first @spaces.GPU call.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=(torch.bfloat16 if torch.cuda.is_available() else "auto"),
)

def _msgs_from_history(history, system_text):
    msgs = []
    if system_text:
        msgs.append({"role": "system", "content": system_text})
    for user, assistant in history:
        if user:
            msgs.append({"role": "user", "content": user})
        if assistant:
            msgs.append({"role": "assistant", "content": assistant})
    return msgs

def _eos_ids(tok):
    ids = {tok.eos_token_id}
    im_end = tok.convert_tokens_to_ids("<|im_end|>")
    if im_end is not None:
        ids.add(im_end)
    return list(ids)

@spaces.GPU(duration=120)  # REQUIRED for ZeroGPU; remove if using standard GPU hardware
def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new):
    """Generate one assistant reply for a chat turn.

    Args:
        message: The new user message.
        history: Prior turns as supplied by Gradio's ChatInterface.
        system_text: Optional system prompt; omitted when empty.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        max_new: Maximum number of tokens to generate.
        min_new: Minimum number of tokens to generate; clamped to
            ``[0, max_new]`` because the UI sliders allow min > max.

    Returns:
        The decoded reply (prompt removed, special tokens skipped, stripped).
    """
    msgs = _msgs_from_history(history, system_text)
    msgs.append({"role": "user", "content": message})
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    # The rendered template already contains every special token it needs;
    # letting the tokenizer add its own again could duplicate them (e.g. BOS).
    inputs = tokenizer([prompt], return_tensors="pt", add_special_tokens=False).to(model.device)

    max_new_tokens = int(max_new)
    # A minimum above the maximum is contradictory; keep the request consistent.
    min_new_tokens = max(0, min(int(min_new), max_new_tokens))

    gen_cfg = GenerationConfig(
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        max_new_tokens=max_new_tokens,
        min_new_tokens=min_new_tokens,
        repetition_penalty=1.02,
        no_repeat_ngram_size=3,
        eos_token_id=_eos_ids(tokenizer),
        pad_token_id=tokenizer.eos_token_id,
    )
    with torch.inference_mode():
        out = model.generate(**inputs, generation_config=gen_cfg)

    # Slice off the prompt so only the newly generated assistant reply is shown.
    new_tokens = out[:, inputs["input_ids"].shape[1]:]
    reply = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
    return reply

with gr.Blocks() as demo:
    gr.Markdown(
        "<h1 style='text-align:center'>Gita Assistant (Qwen2.5-3B Fine-tuned)</h1>"
        "<p style='text-align:center'>Ask in English / हिंदी / ગુજરાતી. The assistant cites verses when relevant.</p>"
    )
    system_box = gr.Textbox(
        value="Reply in the user’s language with 2–3 concise points (200–400 words); cite Gita verses when relevant.",
        label="System prompt",
    )
    temperature = gr.Slider(0.1, 1.2, value=0.7, step=0.05, label="temperature")
    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
    max_new = gr.Slider(64, 1024, value=512, step=16, label="max_new_tokens")
    min_new = gr.Slider(0, 512, value=160, step=8, label="min_new_tokens")

    # ChatInterface calls fn(message, history, *additional_inputs) with the
    # *current* component values, so chat_fn's signature matches directly.
    # (Reading component.value inside a wrapper would only ever see the
    # initial defaults, and a 2-argument wrapper cannot accept the 7 args.)
    chat = gr.ChatInterface(
        fn=chat_fn,
        title=None,
        additional_inputs=[system_box, temperature, top_p, max_new, min_new],
        retry_btn="Regenerate",
        undo_btn="Undo Last",
        clear_btn="Clear",
    )

# ChatInterface has no `queue` parameter; enable request queueing on the app
# itself (recommended, and needed for ZeroGPU concurrency).
demo.queue()

if __name__ == "__main__":
    demo.launch()