JDhruv14 committed
Commit 8a5aaef · verified · 1 Parent(s): 0037da4

Update app.py

Files changed (1)
  1. app.py +34 -31
app.py CHANGED
@@ -1,18 +1,32 @@
 import os, torch, gradio as gr, spaces
 from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
+from peft import PeftModel
 
-MODEL_ID = os.getenv("MODEL_ID", "JDhruv14/Gita-FT-v2-Qwen2.5-3B")
+# ---- IDs (can override from Space Secrets) ----
+BASE_ID = os.getenv("BASE_ID", "Qwen/Qwen2.5-3B-Instruct")
+ADAPTER_ID = os.getenv("ADAPTER_ID", "JDhruv14/Gita-FT-v2-Qwen2.5-3B")
+
+# ---- Load tokenizer & base model ----
+tokenizer = AutoTokenizer.from_pretrained(BASE_ID, trust_remote_code=True)
 
-# Load once (CPU until first call; device_map will move to GPU on first run)
-tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 model = AutoModelForCausalLM.from_pretrained(
-    MODEL_ID,
+    BASE_ID,
     device_map="auto",
     torch_dtype=torch.bfloat16 if torch.cuda.is_available() else "auto",
     trust_remote_code=True,
 )
+# Apply LoRA adapter
+model = PeftModel.from_pretrained(model, ADAPTER_ID)
+model.eval()
+
+def _eos_ids(tok):
+    ids = {tok.eos_token_id}
+    im_end = tok.convert_tokens_to_ids("<|im_end|>")
+    if im_end is not None:
+        ids.add(im_end)
+    return list(ids)
 
-def _msgs_from_history(history, system_text):
+def _format_history(history, system_text):
     msgs = []
     if system_text:
         msgs.append({"role": "system", "content": system_text})
@@ -23,16 +37,9 @@ def _msgs_from_history(history, system_text):
         msgs.append({"role": "assistant", "content": assistant})
     return msgs
 
-def _eos_ids(tok):
-    ids = {tok.eos_token_id}
-    im_end = tok.convert_tokens_to_ids("<|im_end|>")
-    if im_end is not None:
-        ids.add(im_end)
-    return list(ids)
-
-@spaces.GPU(duration=120)  # REQUIRED for ZeroGPU; remove if using standard GPU hardware
-def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new):
-    msgs = _msgs_from_history(history, system_text) + [{"role": "user", "content": message}]
+@spaces.GPU(duration=120)  # keep for ZeroGPU; remove this decorator if using a normal GPU Space
+def chat_fn(message, history, system_text, temperature, top_p, max_new_tokens, min_new_tokens):
+    msgs = _format_history(history, system_text) + [{"role": "user", "content": message}]
     prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
     inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
 
@@ -40,43 +47,39 @@ def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new)
         do_sample=True,
         temperature=float(temperature),
         top_p=float(top_p),
-        max_new_tokens=int(max_new),
-        min_new_tokens=int(min_new),
+        max_new_tokens=int(max_new_tokens),
+        min_new_tokens=int(min_new_tokens),
         repetition_penalty=1.02,
         no_repeat_ngram_size=3,
         eos_token_id=_eos_ids(tokenizer),
         pad_token_id=tokenizer.eos_token_id,
     )
     with torch.no_grad():
-        out = model.generate(**inputs, generation_config=gen_cfg)
+        outputs = model.generate(**inputs, generation_config=gen_cfg)
 
-    # slice off the prompt so we show only the assistant reply
-    new_tokens = out[:, inputs["input_ids"].shape[1]:]
+    # show only the assistant reply (slice off the prompt)
+    new_tokens = outputs[:, inputs["input_ids"].shape[1]:]
     reply = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
     return reply
 
 with gr.Blocks() as demo:
     gr.Markdown(
-        "<h1 style='text-align:center'>Gita Assistant (Qwen2.5-3B Fine-tuned)</h1>"
+        "<h1 style='text-align:center'>Gita Assistant (Qwen2.5-3B + LoRA)</h1>"
         "<p style='text-align:center'>Ask in English / हिंदी / ગુજરાતી. The assistant cites verses when relevant.</p>"
     )
     system_box = gr.Textbox(
-        value="Reply in the user’s language with 2–3 concise points (200–400 words); cite Gita verses when relevant.",
+        value="Reply in the user’s language with 2–3 concrete points (200–400 words); cite Gita verses when relevant.",
         label="System prompt",
     )
     temperature = gr.Slider(0.1, 1.2, value=0.7, step=0.05, label="temperature")
-    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
-    max_new = gr.Slider(64, 1024, value=512, step=16, label="max_new_tokens")
-    min_new = gr.Slider(0, 512, value=160, step=8, label="min_new_tokens")
+    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
+    max_new = gr.Slider(64, 1024, value=512, step=16, label="max_new_tokens")
+    min_new = gr.Slider(0, 512, value=160, step=8, label="min_new_tokens")
 
-    chat = gr.ChatInterface(
+    gr.ChatInterface(
         fn=lambda m, h: chat_fn(m, h, system_box.value, temperature.value, top_p.value, max_new.value, min_new.value),
-        title=None,
         additional_inputs=[system_box, temperature, top_p, max_new, min_new],
-        retry_btn="Regenerate",
-        undo_btn="Undo Last",
-        clear_btn="Clear",
-        queue=True,  # queue is recommended (and required for ZeroGPU concurrency)
+        retry_btn="Regenerate", undo_btn="Undo Last", clear_btn="Clear", queue=True
     )
 
 if __name__ == "__main__":
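
In short, the commit stops loading a fully merged checkpoint (MODEL_ID) and instead loads the Qwen/Qwen2.5-3B-Instruct base model and applies the Gita LoRA adapter on top via peft. A minimal sketch of that loading pattern follows; the IDs are taken from the diff, while the merge_and_unload() step is an optional extra that is NOT part of this commit:

# Sketch of the base-plus-adapter loading pattern this commit switches to.
# merge_and_unload() is optional and not in the commit: it folds the LoRA
# deltas into the base weights so generate() runs without the PEFT wrapper.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

BASE_ID = "Qwen/Qwen2.5-3B-Instruct"
ADAPTER_ID = "JDhruv14/Gita-FT-v2-Qwen2.5-3B"

tokenizer = AutoTokenizer.from_pretrained(BASE_ID, trust_remote_code=True)
base = AutoModelForCausalLM.from_pretrained(
    BASE_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else "auto",
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(base, ADAPTER_ID)
model = model.merge_and_unload()  # optional: bake the adapter into the base
model.eval()

Merging trades adapter flexibility for slightly lower per-token overhead; keeping the PeftModel wrapper, as the commit does, makes it easy to update the adapter repo independently of the base weights.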
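One caveat on the relocated _eos_ids helper: for fast tokenizers, convert_tokens_to_ids typically returns the unk id rather than None when a token is missing from the vocabulary, so the `is not None` guard may not catch an absent <|im_end|> (it is present in Qwen's vocab, so the committed code works in practice). A slightly more defensive variant, as a hypothetical refinement rather than part of the commit:

def _eos_ids_safe(tok):
    # Collect valid end-of-turn ids; drop None and the unk fallback that
    # convert_tokens_to_ids may return for out-of-vocab tokens.
    ids = {tok.eos_token_id, tok.convert_tokens_to_ids("<|im_end|>")}
    return [i for i in ids if i is not None and i != tok.unk_token_id]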