import os

import gradio as gr
import spaces
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

# ---- IDs (can override from Space Secrets) ----
BASE_ID    = os.getenv("BASE_ID",    "Qwen/Qwen2.5-3B-Instruct")
ADAPTER_ID = os.getenv("ADAPTER_ID", "JDhruv14/Gita-FT-v2-Qwen2.5-3B")

# ---- Load tokenizer & base model ----
tokenizer = AutoTokenizer.from_pretrained(BASE_ID, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    BASE_ID,
    device_map="auto",
    # bf16 on GPU; otherwise fall back to the checkpoint's own dtype
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else "auto",
    trust_remote_code=True,
)
# Apply LoRA adapter
model = PeftModel.from_pretrained(model, ADAPTER_ID)
model.eval()
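
# Optional (an assumption, not part of the original app): merging the LoRA
# weights into the base model removes PEFT's adapter indirection at inference
# time, at the cost of no longer being able to swap adapters:
# model = model.merge_and_unload()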

def _eos_ids(tok):
    # Stop on the model's EOS token and on Qwen's end-of-turn marker.
    ids = {tok.eos_token_id}
    im_end = tok.convert_tokens_to_ids("<|im_end|>")
    if im_end is not None:
        ids.add(im_end)
    ids.discard(None)  # guard against a tokenizer with no eos_token_id
    return list(ids)

def _format_history(history, system_text):
    msgs = []
    if system_text:
        msgs.append({"role": "system", "content": system_text})
    for user, assistant in history:
        if user:
            msgs.append({"role": "user", "content": user})
        if assistant:
            msgs.append({"role": "assistant", "content": assistant})
    return msgs
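
# Illustrative example of the structure _format_history produces:
#   _format_history([("hi", "hello")], "Be brief")
#   -> [{"role": "system", "content": "Be brief"},
#       {"role": "user", "content": "hi"},
#       {"role": "assistant", "content": "hello"}]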

@spaces.GPU(duration=120)   # keep for ZeroGPU; remove this decorator if using a normal GPU Space
def chat_fn(message, history, system_text, temperature, top_p, max_new_tokens, min_new_tokens):
    msgs = _format_history(history, system_text) + [{"role": "user", "content": message}]
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

    gen_cfg = GenerationConfig(
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        max_new_tokens=int(max_new_tokens),
        min_new_tokens=int(min_new_tokens),
        repetition_penalty=1.02,
        no_repeat_ngram_size=3,
        eos_token_id=_eos_ids(tokenizer),
        pad_token_id=tokenizer.eos_token_id,
    )
    with torch.no_grad():
        outputs = model.generate(**inputs, generation_config=gen_cfg)

    # show only the assistant reply (slice off the prompt)
    new_tokens = outputs[:, inputs["input_ids"].shape[1]:]
    reply = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
    return reply
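
# Streaming variant (optional sketch, assuming transformers' TextIteratorStreamer
# and Gradio's support for generator chat functions): run generate() in a
# background thread and yield partial text as it arrives.
#
#   from threading import Thread
#   from transformers import TextIteratorStreamer
#
#   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
#   Thread(target=model.generate,
#          kwargs={**inputs, "generation_config": gen_cfg, "streamer": streamer}).start()
#   partial = ""
#   for chunk in streamer:
#       partial += chunk
#       yield partial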

with gr.Blocks() as demo:
    gr.Markdown(
        "<h1 style='text-align:center'>Gita Assistant (Qwen2.5-3B + LoRA)</h1>"
        "<p style='text-align:center'>Ask in English / हिंदी / ગુજરાતી. The assistant cites verses when relevant.</p>"
    )

    system_box  = gr.Textbox(
        value="Reply in the user’s language with 2–3 concrete points (200–400 words); cite Gita verses when relevant.",
        label="System prompt",
    )
    temperature = gr.Slider(0.1, 1.2, value=0.7, step=0.05, label="temperature")
    top_p       = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
    max_new     = gr.Slider(64, 1024, value=512, step=16, label="max_new_tokens")
    min_new     = gr.Slider(0, 512, value=160, step=8, label="min_new_tokens")

    gr.ChatInterface(
        fn=chat_fn,  # signature: chat_fn(message, history, system_text, temperature, top_p, max_new_tokens, min_new_tokens)
        additional_inputs=[system_box, temperature, top_p, max_new, min_new],
        chatbot=gr.Chatbot(height=520, type="tuples"),  # keep tuple history; no behavior change
        examples=[
            ["How do I practice Nishkama Karma at work?", system_box.value, 0.7, 0.9, 512, 160],
            ["What does 3.19 teach about duty without attachment?", system_box.value, 0.7, 0.9, 512, 160],
            ["How to overcome fear of failure according to the Gita?", system_box.value, 0.7, 0.9, 512, 160],
        ],
    )

if __name__ == "__main__":
    demo.launch()