JDhruv14 committed on
Commit
f03b213
·
verified ·
1 Parent(s): 6a4053c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -14
app.py CHANGED
@@ -12,6 +12,10 @@ model = AutoModelForCausalLM.from_pretrained(
12
  trust_remote_code=True,
13
  )
14
 
 
 
 
 
15
  def _msgs_from_history(history, system_text):
16
  msgs = []
17
  if system_text:
@@ -24,23 +28,29 @@ def _msgs_from_history(history, system_text):
24
  return msgs
25
 
26
  def _eos_ids(tok):
27
- ids = {tok.eos_token_id}
28
- im_end = tok.convert_tokens_to_ids("<|im_end|>")
29
- if im_end is not None:
30
- ids.add(im_end)
 
 
 
 
 
 
 
 
 
 
31
  return list(ids)
32
 
33
- @spaces.GPU() # REQUIRED for ZeroGPU; remove if using standard GPU hardware
34
- def gradio_fn(message, history):
35
- response = infer_text(history + [(message, None)])
36
- return response
37
-
38
  def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new):
39
  msgs = _msgs_from_history(history, system_text) + [{"role": "user", "content": message}]
40
  prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
41
  inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
42
 
43
- gen_cfg = GenerationConfig(
 
44
  do_sample=True,
45
  temperature=float(temperature),
46
  top_p=float(top_p),
@@ -48,9 +58,13 @@ def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new)
48
  min_new_tokens=int(min_new),
49
  repetition_penalty=1.02,
50
  no_repeat_ngram_size=3,
51
- eos_token_id=_eos_ids(tokenizer),
52
- pad_token_id=tokenizer.eos_token_id,
53
  )
 
 
 
 
 
54
  with torch.no_grad():
55
  out = model.generate(**inputs, generation_config=gen_cfg)
56
 
@@ -59,11 +73,17 @@ def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new)
59
  reply = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
60
  return reply
61
 
 
 
 
 
 
62
  with gr.Blocks() as demo:
63
  gr.Markdown(
64
  "<h1 style='text-align:center'>Gita Assistant (Qwen2.5-3B Fine-tuned)</h1>"
65
  "<p style='text-align:center'>Ask in English / हिंदी / ગુજરાતી. The assistant cites verses when relevant.</p>"
66
  )
 
67
  system_box = gr.Textbox(
68
  value="Reply in the user’s language with 2–3 concise points (200–400 words); cite Gita verses when relevant.",
69
  label="System prompt",
@@ -75,6 +95,7 @@ with gr.Blocks() as demo:
75
 
76
  chat = gr.ChatInterface(
77
  fn=gradio_fn,
 
78
  examples=[
79
  "Hello!",
80
  "How can I overcome fear of failure?",
@@ -82,8 +103,7 @@ with gr.Blocks() as demo:
82
  "What can I do to stop overthinking?"
83
  ],
84
  chatbot=gr.Chatbot(elem_classes="chatbot"),
85
- theme="compact",
86
  )
87
 
88
  if __name__ == "__main__":
89
- demo.launch()
 
12
  trust_remote_code=True,
13
  )
14
 
15
+ # Ensure pad token exists (many chat models reuse EOS as PAD)
16
+ if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
17
+ tokenizer.pad_token = tokenizer.eos_token
18
+
19
  def _msgs_from_history(history, system_text):
20
  msgs = []
21
  if system_text:
 
28
  return msgs
29
 
30
  def _eos_ids(tok):
31
+ # Support ints/lists and optional <|im_end|>
32
+ ids = set()
33
+ if tok.eos_token_id is not None:
34
+ if isinstance(tok.eos_token_id, (list, tuple)):
35
+ ids.update(tok.eos_token_id)
36
+ else:
37
+ ids.add(tok.eos_token_id)
38
+ try:
39
+ im_end = tok.convert_tokens_to_ids("<|im_end|>")
40
+ if im_end is not None and im_end != tok.unk_token_id:
41
+ ids.add(im_end)
42
+ except Exception:
43
+ pass
44
+ # Fallback: if still empty, just skip setting eos_token_id in GenerationConfig
45
  return list(ids)
46
 
 
 
 
 
 
47
  def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new):
48
  msgs = _msgs_from_history(history, system_text) + [{"role": "user", "content": message}]
49
  prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
50
  inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
51
 
52
+ eos = _eos_ids(tokenizer)
53
+ gen_cfg_kwargs = dict(
54
  do_sample=True,
55
  temperature=float(temperature),
56
  top_p=float(top_p),
 
58
  min_new_tokens=int(min_new),
59
  repetition_penalty=1.02,
60
  no_repeat_ngram_size=3,
61
+ pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
 
62
  )
63
+ if eos:
64
+ gen_cfg_kwargs["eos_token_id"] = eos
65
+
66
+ gen_cfg = GenerationConfig(**gen_cfg_kwargs)
67
+
68
  with torch.no_grad():
69
  out = model.generate(**inputs, generation_config=gen_cfg)
70
 
 
73
  reply = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
74
  return reply
75
 
76
# ZeroGPU-decorated entry point that gr.ChatInterface calls for each turn.
@spaces.GPU()  # REQUIRED for ZeroGPU; remove if using standard GPU hardware
def gradio_fn(message, history, system_text, temperature, top_p, max_new, min_new):
    """Forward one chat turn to ``chat_fn`` and return its reply.

    Args:
        message: The user's latest message.
        history: Prior conversation turns as supplied by the chat UI.
        system_text: System prompt text.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        max_new: Upper bound on newly generated tokens.
        min_new: Lower bound on newly generated tokens.

    Returns:
        Whatever ``chat_fn`` returns for these arguments.
    """
    # Bundle the extra UI controls so the call reads as "turn + settings".
    generation_settings = (system_text, temperature, top_p, max_new, min_new)
    reply = chat_fn(message, history, *generation_settings)
    return reply
80
+
81
  with gr.Blocks() as demo:
82
  gr.Markdown(
83
  "<h1 style='text-align:center'>Gita Assistant (Qwen2.5-3B Fine-tuned)</h1>"
84
  "<p style='text-align:center'>Ask in English / हिंदी / ગુજરાતી. The assistant cites verses when relevant.</p>"
85
  )
86
+
87
  system_box = gr.Textbox(
88
  value="Reply in the user’s language with 2–3 concise points (200–400 words); cite Gita verses when relevant.",
89
  label="System prompt",
 
95
 
96
  chat = gr.ChatInterface(
97
  fn=gradio_fn,
98
+ additional_inputs=[system_box, temperature, top_p, max_new, min_new],
99
  examples=[
100
  "Hello!",
101
  "How can I overcome fear of failure?",
 
103
  "What can I do to stop overthinking?"
104
  ],
105
  chatbot=gr.Chatbot(elem_classes="chatbot"),
 
106
  )
107
 
108
  if __name__ == "__main__":
109
+ demo.launch()