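# app.py: Gradio chat UI for a Qwen2.5-3B model fine-tuned on the Bhagavad Gita,
# hosted on a Hugging Face ZeroGPU Space (see the @spaces.GPU decorator below).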
import os

import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

MODEL_ID = os.getenv("MODEL_ID", "JDhruv14/merged_model")

# Load once (CPU until first call; device_map will move to GPU on first run)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else "auto",
    trust_remote_code=True,
)

# Ensure pad token exists (many chat models reuse EOS as PAD)
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
    tokenizer.pad_token = tokenizer.eos_token

def _msgs_from_history(history, system_text):
    msgs = []
    if system_text:
        msgs.append({"role": "system", "content": system_text})
    for user, assistant in history:
        if user:
            msgs.append({"role": "user", "content": user})
        if assistant:
            msgs.append({"role": "assistant", "content": assistant})
    return msgs
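
# NOTE: _msgs_from_history assumes Gradio's pair-format ("tuples") history,
# e.g. [["hi", "hello!"]]; a ChatInterface created with type="messages"
# passes role/content dicts instead, which could be forwarded directly.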

def _eos_ids(tok):
    # Support ints/lists and optional <|im_end|>
    ids = set()
    if tok.eos_token_id is not None:
        if isinstance(tok.eos_token_id, (list, tuple)):
            ids.update(tok.eos_token_id)
        else:
            ids.add(tok.eos_token_id)
    try:
        im_end = tok.convert_tokens_to_ids("<|im_end|>")
        if im_end is not None and im_end != tok.unk_token_id:
            ids.add(im_end)
    except Exception:
        pass
    # If nothing was found, return an empty list; the caller then leaves
    # eos_token_id unset so GenerationConfig falls back to the model default.
    return list(ids)
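
# Depending on the checkpoint, the configured EOS differs: base Qwen2.5 models
# use <|endoftext|> while the instruct variants use <|im_end|>, so the explicit
# lookup in _eos_ids() makes end-of-turn stopping robust either way.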

def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new):
    msgs = _msgs_from_history(history, system_text) + [{"role": "user", "content": message}]
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

    eos = _eos_ids(tokenizer)
    gen_cfg_kwargs = dict(
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        max_new_tokens=int(max_new),
        min_new_tokens=int(min_new),
        repetition_penalty=1.02,
        no_repeat_ngram_size=3,
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
    )
    if eos:
        gen_cfg_kwargs["eos_token_id"] = eos

    gen_cfg = GenerationConfig(**gen_cfg_kwargs)

    with torch.no_grad():
        out = model.generate(**inputs, generation_config=gen_cfg)

    # slice off the prompt so we show only the assistant reply
    new_tokens = out[:, inputs["input_ids"].shape[1]:]
    reply = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
    return reply
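
# (For token-by-token streaming, chat_fn could `yield` partial text instead,
# using transformers' TextIteratorStreamer with generate() running on a
# background thread; gr.ChatInterface accepts generator functions.)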

# Wrap for ChatInterface + ZeroGPU
@spaces.GPU()  # REQUIRED for ZeroGPU; remove if using standard GPU hardware
def gradio_fn(message, history, system_text, temperature, top_p, max_new, min_new):
    return chat_fn(message, history, system_text, temperature, top_p, max_new, min_new)
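
# If replies routinely exceed the default ZeroGPU time slice, spaces.GPU also
# accepts a duration hint in seconds, e.g. @spaces.GPU(duration=120).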

with gr.Blocks() as demo:
    gr.Markdown(
        "<h1 style='text-align:center'>Gita Assistant (Qwen2.5-3B Fine-tuned)</h1>"
        "<p style='text-align:center'>Ask in English / हिंदी / ગુજરાતી. The assistant cites verses when relevant.</p>"
    )

    system_box = gr.Textbox(
        value="Reply in the user’s language with 2–3 concise points (200–400 words); cite Gita verses when relevant.",
        label="System prompt",
    )
    temperature = gr.Slider(0.1, 1.2, value=0.7, step=0.05, label="temperature")
    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
    max_new = gr.Slider(64, 1024, value=512, step=16, label="max_new_tokens")
    min_new = gr.Slider(0, 512, value=160, step=8, label="min_new_tokens")

    chat = gr.ChatInterface(
        fn=gradio_fn,
        additional_inputs=[system_box, temperature, top_p, max_new, min_new],
        examples=[
            "Hello!",
            "How can I overcome fear of failure?",
            "How do I forgive someone who hurt me deeply?",
            "What can I do to stop overthinking?"
        ],
        chatbot=gr.Chatbot(elem_classes="chatbot"),
    )

if __name__ == "__main__":
    demo.launch()
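
# Assumed local workflow (not part of the Space): pip install gradio torch
# transformers spaces, then `python app.py`; outside a ZeroGPU environment the
# @spaces.GPU decorator is a no-op.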