JDhruv14 committed on
Commit
33dd5ba
·
verified ·
1 Parent(s): b815887

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -104
app.py CHANGED
@@ -1,121 +1,83 @@
1
- import torch
2
- torch._dynamo.config.disable = True
3
- from collections.abc import Iterator
4
- from transformers import (
5
- Gemma3ForConditionalGeneration,
6
- TextIteratorStreamer,
7
- Gemma3Processor,
8
- Gemma3nForConditionalGeneration,
9
- )
10
- import gradio as gr
11
- import os
12
- import spaces
13
 
14
- # Load environment variables
15
- model_3n_id = os.getenv("MODEL_3N_ID", "JDhruv14/merged_model")
16
 
17
- # Load model and processor
18
- model_3n = Gemma3nForConditionalGeneration.from_pretrained(
19
- model_3n_id,
20
- dtype=torch.bfloat16,
21
  device_map="auto",
22
- attn_implementation="eager"
 
23
  )
24
- input_processor = Gemma3Processor.from_pretrained(model_3n_id)
25
 
26
- def infer_text(messages, max_new_tokens=300, temperature=1.0, top_p=0.95, top_k=64, repetition_penalty=1.1):
27
- chat_template = []
28
- for turn in messages:
29
- if turn[0]:
30
- chat_template.append({"role": "user", "content": [{"type": "text", "text": turn[0]}]})
31
- if turn[1]:
32
- chat_template.append({"role": "assistant", "content": [{"type": "text", "text": turn[1]}]})
33
- chat_template.append({"role": "assistant", "content": [{"type": "text", "text": ""}]})
 
 
34
 
35
- inputs = input_processor.apply_chat_template(
36
- chat_template,
37
- add_generation_prompt=True,
38
- tokenize=True,
39
- return_dict=True,
40
- return_tensors="pt",
41
- ).to(device=model_3n.device, dtype=torch.bfloat16)
42
 
43
- with torch.no_grad():
44
- output_tokens = model_3n.generate(
45
- **inputs,
46
- max_new_tokens=max_new_tokens,
47
- temperature=temperature,
48
- top_p=top_p,
49
- top_k=top_k,
50
- repetition_penalty=repetition_penalty,
51
- do_sample=True,
52
- )
53
 
54
- generated_text = input_processor.batch_decode(
55
- output_tokens[:, inputs['input_ids'].shape[1]:], skip_special_tokens=True
56
- )[0]
57
- return generated_text.strip()
 
 
 
 
 
 
 
 
 
58
 
59
- @spaces.GPU()
60
- def gradio_fn(message, history):
61
- response = infer_text(history + [(message, None)])
62
- return response
63
 
64
- with gr.Blocks(css="""
65
- .gradio-container {
66
- max-width: 600px;
67
- margin: auto;
68
- padding: 20px;
69
- font-family: sans-serif;
70
- position: relative;
71
- }
72
- .chatbot {
73
- height: 500px !important;
74
- overflow-y: auto;
75
- }
76
- .corner {
77
- position: fixed;
78
- bottom: 2px;
79
- z-index: 9999;
80
- pointer-events: none;
81
- }
82
- #left { left: 2px; }
83
- #right { right: 2px; }
84
- .corner img {
85
- height: 500px; /* fixed height */
86
- width: auto; /* auto to keep aspect ratio */
87
- }
88
-
89
- """) as demo:
90
  gr.Markdown(
91
- """
92
- <div style='text-align: center; padding: 10px;'>
93
- <h1 style='font-size: 2.2em; margin-bottom: 0.2em;'>🤖 <span style='color: #4F46E5;'>kRISHNA.ai</span></h1>
94
- <p style='font-size: 1.1em; color: #555;'>5000-Years of Ancient WISDOM with Modern AI ✨</p>
95
- </div>
96
- """,
97
- elem_id="header"
98
  )
99
- chat = gr.ChatInterface(
100
- fn=gradio_fn,
101
- examples=[
102
- "Hello!",
103
- "How can I overcome fear of failure?",
104
- "How do I forgive someone who hurt me deeply?",
105
- "What can I do to stop overthinking?"
106
- ],
107
- chatbot=gr.Chatbot(elem_classes="chatbot"),
108
- theme="compact",
109
  )
110
- gr.HTML(f"""
111
- <div id="left" class="corner">
112
- <img src="https://huggingface.co/spaces/p2kalita/kRISHNA.ai/resolve/main/assets/Arjun.png" alt="Arjun">
113
- </div>
114
- <div id="right" class="corner">
115
- <img src="https://huggingface.co/spaces/p2kalita/kRISHNA.ai/resolve/main/assets/Krishna.png" alt="Krishna">
116
- </div>
117
- """)
118
 
 
 
 
 
 
 
 
 
 
119
 
120
  if __name__ == "__main__":
121
  demo.launch()
 
1
+ import os, torch, gradio as gr, spaces
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
 
 
 
 
 
 
 
 
 
 
3
 
4
# Checkpoint is configurable via the environment, defaulting to the merged model.
MODEL_ID = os.getenv("MODEL_ID", "JDhruv14/merged_model")

# Prefer bf16 on CUDA machines; otherwise let transformers pick ("auto").
_dtype = torch.bfloat16 if torch.cuda.is_available() else "auto"

# Load once at import time (CPU until first call; device_map moves weights to GPU on first run).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=_dtype,
)
 
14
 
15
+ def _msgs_from_history(history, system_text):
16
+ msgs = []
17
+ if system_text:
18
+ msgs.append({"role": "system", "content": system_text})
19
+ for user, assistant in history:
20
+ if user:
21
+ msgs.append({"role": "user", "content": user})
22
+ if assistant:
23
+ msgs.append({"role": "assistant", "content": assistant})
24
+ return msgs
25
 
26
+ def _eos_ids(tok):
27
+ ids = {tok.eos_token_id}
28
+ im_end = tok.convert_tokens_to_ids("<|im_end|>")
29
+ if im_end is not None:
30
+ ids.add(im_end)
31
+ return list(ids)
 
32
 
33
@spaces.GPU(duration=120)  # REQUIRED for ZeroGPU; remove if using standard GPU hardware
def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new):
    """Generate one assistant reply for the current chat turn.

    Args:
        message: the new user message text.
        history: prior (user, assistant) tuple pairs from Gradio.
        system_text: system prompt prepended to the conversation.
        temperature, top_p, max_new, min_new: sampling controls from the UI.

    Returns:
        The decoded assistant reply, stripped of surrounding whitespace.
    """
    conversation = _msgs_from_history(history, system_text)
    conversation.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

    sampling = GenerationConfig(
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        max_new_tokens=int(max_new),
        min_new_tokens=int(min_new),
        repetition_penalty=1.02,
        no_repeat_ngram_size=3,
        eos_token_id=_eos_ids(tokenizer),
        pad_token_id=tokenizer.eos_token_id,
    )

    with torch.no_grad():
        generated = model.generate(**model_inputs, generation_config=sampling)

    # Slice off the prompt tokens so the UI shows only the assistant reply.
    prompt_len = model_inputs["input_ids"].shape[1]
    reply_tokens = generated[:, prompt_len:]
    return tokenizer.batch_decode(reply_tokens, skip_special_tokens=True)[0].strip()
57
 
58
with gr.Blocks() as demo:
    gr.Markdown(
        "<h1 style='text-align:center'>Gita Assistant (Qwen2.5-3B Fine-tuned)</h1>"
        "<p style='text-align:center'>Ask in English / हिंदी / ગુજરાતી. The assistant cites verses when relevant.</p>"
    )
    system_box = gr.Textbox(
        value="Reply in the user’s language with 2–3 concise points (200–400 words); cite Gita verses when relevant.",
        label="System prompt",
    )
    temperature = gr.Slider(0.1, 1.2, value=0.7, step=0.05, label="temperature")
    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
    max_new = gr.Slider(64, 1024, value=512, step=16, label="max_new_tokens")
    min_new = gr.Slider(0, 512, value=160, step=8, label="min_new_tokens")

    # BUG FIX: the original wrapped chat_fn in a two-argument lambda that read
    # component .value attributes (the static defaults, never the live values)
    # while ALSO passing additional_inputs.  With additional_inputs set,
    # ChatInterface calls fn with (message, history, *extra) — seven arguments —
    # so the two-arg lambda raised TypeError on every submit.  Passing chat_fn
    # directly lets Gradio supply the live component values in order.
    chat = gr.ChatInterface(
        fn=chat_fn,
        title=None,
        additional_inputs=[system_box, temperature, top_p, max_new, min_new],
        retry_btn="Regenerate",  # NOTE(review): *_btn kwargs were removed in Gradio 5 — confirm installed version
        undo_btn="Undo Last",
        clear_btn="Clear",
    )

# BUG FIX: ChatInterface takes no `queue` keyword; queuing is enabled on the
# Blocks app itself (recommended, and required for ZeroGPU concurrency).
demo.queue()

if __name__ == "__main__":
    demo.launch()