File size: 3,404 Bytes
318503e
 
26009d1
318503e
26009d1
318503e
 
 
 
 
 
 
 
26009d1
318503e
 
 
 
 
 
 
 
 
 
26009d1
318503e
 
 
 
 
 
26009d1
318503e
 
 
 
 
26009d1
318503e
 
 
 
 
 
 
 
 
 
 
 
 
26009d1
318503e
 
 
 
26009d1
 
318503e
 
 
 
 
 
 
 
 
 
 
 
26009d1
318503e
 
 
 
 
 
 
 
 
26009d1
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os, torch, gradio as gr, spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

# Hub repo to serve; can be overridden through the MODEL_ID environment variable.
MODEL_ID = os.getenv("MODEL_ID", "JDhruv14/Gita-FT-v2-Qwen2.5-3B")

# Tokenizer and weights are loaded a single time, at import, and shared by every request.
# NOTE(review): placement is decided here by device_map="auto"; on ZeroGPU the GPU is
# attached around @spaces.GPU calls — confirm where the weights actually end up.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Use bf16 when CUDA is visible at load time; otherwise let transformers choose the dtype.
if torch.cuda.is_available():
    _weights_dtype = torch.bfloat16
else:
    _weights_dtype = "auto"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=_weights_dtype,
    device_map="auto",
)

def _msgs_from_history(history, system_text):
    msgs = []
    if system_text:
        msgs.append({"role": "system", "content": system_text})
    for user, assistant in history:
        if user:
            msgs.append({"role": "user", "content": user})
        if assistant:
            msgs.append({"role": "assistant", "content": assistant})
    return msgs

def _eos_ids(tok):
    ids = {tok.eos_token_id}
    im_end = tok.convert_tokens_to_ids("<|im_end|>")
    if im_end is not None:
        ids.add(im_end)
    return list(ids)

@spaces.GPU(duration=120)  # REQUIRED for ZeroGPU; remove if using standard GPU hardware
def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new):
    """Generate one sampled assistant reply to ``message``.

    Args:
        message: The new user turn.
        history: Earlier chat turns, converted with ``_msgs_from_history``.
        system_text: System prompt prepended to the conversation (may be empty).
        temperature: Sampling temperature (cast to float).
        top_p: Nucleus-sampling threshold (cast to float).
        max_new: Maximum number of new tokens to generate (cast to int).
        min_new: Minimum number of new tokens before a stop token is allowed
            (cast to int).

    Returns:
        The decoded reply only (prompt tokens sliced off, special tokens
        skipped, surrounding whitespace stripped).
    """
    msgs = _msgs_from_history(history, system_text) + [{"role": "user", "content": message}]
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    # The rendered template already contains whatever special tokens the model
    # expects; adding them again on tokenization could duplicate e.g. a BOS token
    # for models chosen through the MODEL_ID override.
    inputs = tokenizer([prompt], return_tensors="pt", add_special_tokens=False).to(model.device)

    gen_cfg = GenerationConfig(
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        max_new_tokens=int(max_new),
        min_new_tokens=int(min_new),
        repetition_penalty=1.02,
        no_repeat_ngram_size=3,
        eos_token_id=_eos_ids(tokenizer),
        pad_token_id=tokenizer.eos_token_id,
    )
    with torch.no_grad():
        out = model.generate(**inputs, generation_config=gen_cfg)

    # Slice off the prompt so only the assistant reply is shown.
    new_tokens = out[:, inputs["input_ids"].shape[1]:]
    reply = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
    return reply

with gr.Blocks() as demo:
    gr.Markdown(
        "<h1 style='text-align:center'>Gita Assistant (Qwen2.5-3B Fine-tuned)</h1>"
        "<p style='text-align:center'>Ask in English / हिंदी / ગુજરાતી. The assistant cites verses when relevant.</p>"
    )
    system_box = gr.Textbox(
        value="Reply in the user’s language with 2–3 concise points (200–400 words); cite Gita verses when relevant.",
        label="System prompt",
    )
    temperature = gr.Slider(0.1, 1.2, value=0.7, step=0.05, label="temperature")
    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
    max_new = gr.Slider(64, 1024, value=512, step=16, label="max_new_tokens")
    min_new = gr.Slider(0, 512, value=160, step=8, label="min_new_tokens")

    chat = gr.ChatInterface(
        # ChatInterface calls fn(message, history, *additional_inputs) with the
        # widgets' *current* values, which matches chat_fn's signature exactly.
        # (Wrapping it in a 2-argument lambda reading `.value` raised TypeError
        # and would only ever have seen the initial widget values.)
        fn=chat_fn,
        title=None,
        additional_inputs=[system_box, temperature, top_p, max_new, min_new],
        retry_btn="Regenerate",
        undo_btn="Undo Last",
        clear_btn="Clear",
    )

# Queueing is a Blocks-level setting, not a ChatInterface keyword; enable it
# explicitly (recommended, and needed for ZeroGPU concurrency).
demo.queue()

if __name__ == "__main__":
    demo.launch()