JDhruv14 committed
Commit e8c693f · verified · 1 Parent(s): e51e513

Update app.py

Files changed (1): app.py +103 -76
app.py CHANGED
@@ -1,94 +1,121 @@
- import os, torch, gradio as gr, spaces
- from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
- from peft import PeftModel
-
- # ---- IDs (can override from Space Secrets) ----
- BASE_ID = os.getenv("BASE_ID", "Qwen/Qwen2.5-3B-Instruct")
- ADAPTER_ID = os.getenv("ADAPTER_ID", "JDhruv14/Gita-FT-v2-Qwen2.5-3B")
-
- # ---- Load tokenizer & base model ----
- tokenizer = AutoTokenizer.from_pretrained(BASE_ID, trust_remote_code=True)
-
- model = AutoModelForCausalLM.from_pretrained(
-     BASE_ID,
      device_map="auto",
-     torch_dtype=torch.bfloat16 if torch.cuda.is_available() else "auto",
-     trust_remote_code=True,
  )
- # Apply LoRA adapter
- model = PeftModel.from_pretrained(model, ADAPTER_ID)
- model.eval()
-
- def _eos_ids(tok):
-     ids = {tok.eos_token_id}
-     im_end = tok.convert_tokens_to_ids("<|im_end|>")
-     if im_end is not None:
-         ids.add(im_end)
-     return list(ids)
-
- def _format_history(history, system_text):
-     msgs = []
-     if system_text:
-         msgs.append({"role": "system", "content": system_text})
-     for user, assistant in history:
-         if user:
-             msgs.append({"role": "user", "content": user})
-         if assistant:
-             msgs.append({"role": "assistant", "content": assistant})
-     return msgs
-
- @spaces.GPU(duration=120)  # keep for ZeroGPU; remove this decorator if using a normal GPU Space
- def chat_fn(message, history, system_text, temperature, top_p, max_new_tokens, min_new_tokens):
-     msgs = _format_history(history, system_text) + [{"role": "user", "content": message}]
-     prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
-     inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
-
-     gen_cfg = GenerationConfig(
-         do_sample=True,
-         temperature=float(temperature),
-         top_p=float(top_p),
-         max_new_tokens=int(max_new_tokens),
-         min_new_tokens=int(min_new_tokens),
-         repetition_penalty=1.02,
-         no_repeat_ngram_size=3,
-         eos_token_id=_eos_ids(tokenizer),
-         pad_token_id=tokenizer.eos_token_id,
-     )
      with torch.no_grad():
-         outputs = model.generate(**inputs, generation_config=gen_cfg)
-
-     # show only the assistant reply (slice off the prompt)
-     new_tokens = outputs[:, inputs["input_ids"].shape[1]:]
-     reply = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
-     return reply
-
- with gr.Blocks() as demo:
-     gr.Markdown(
-         "<h1 style='text-align:center'>Gita Assistant (Qwen2.5-3B + LoRA)</h1>"
-         "<p style='text-align:center'>Ask in English / हिंदी / ગુજરાતી. The assistant cites verses when relevant.</p>"
-     )
-
-     system_box = gr.Textbox(
-         value="Reply in the user’s language with 2–3 concrete points (200–400 words); cite Gita verses when relevant.",
-         label="System prompt",
      )
-     temperature = gr.Slider(0.1, 1.2, value=0.7, step=0.05, label="temperature")
-     top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
-     max_new = gr.Slider(64, 1024, value=512, step=16, label="max_new_tokens")
-     min_new = gr.Slider(0, 512, value=160, step=8, label="min_new_tokens")
-
-     gr.ChatInterface(
-         fn=chat_fn,  # def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new)
-         additional_inputs=[system_box, temperature, top_p, max_new, min_new],
-         chatbot=gr.Chatbot(height=520, type="tuples"),  # keep tuple history; no behavior change
          examples=[
-             ["How do I practice Nishkama Karma at work?", system_box.value, 0.7, 0.9, 512, 160],
-             ["What does 3.19 teach about duty without attachment?", system_box.value, 0.7, 0.9, 512, 160],
-             ["How to overcome fear of failure according to the Gita?", system_box.value, 0.7, 0.9, 512, 160],
          ],
      )

  if __name__ == "__main__":
      demo.launch()
-
-

+ import torch
+ torch._dynamo.config.disable = True
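+ # note: turns off torch.compile/TorchDynamo graph capture; presumably a
+ # startup/stability workaround on ZeroGPU (an assumption, not stated here)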
+ from collections.abc import Iterator
+ from transformers import (
+     Gemma3ForConditionalGeneration,
+     TextIteratorStreamer,
+     Gemma3Processor,
+     Gemma3nForConditionalGeneration,
+ )
+ import gradio as gr
+ import os
+ import spaces
+
+ # Load environment variables
+ model_3n_id = os.getenv("MODEL_3N_ID", "JDhruv14/merged_model")
+
+ # Load model and processor
+ model_3n = Gemma3nForConditionalGeneration.from_pretrained(
+     model_3n_id,
+     dtype=torch.bfloat16,
      device_map="auto",
+     attn_implementation="eager"
  )
+ input_processor = Gemma3Processor.from_pretrained(model_3n_id)
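+ # (eager attention sidesteps flash-attention kernel requirements, and the
+ # processor supplies the chat template used below; likely why both were chosen)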
+
+ def infer_text(messages, max_new_tokens=300, temperature=1.0, top_p=0.95, top_k=64, repetition_penalty=1.1):
+     chat_template = []
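+     # "messages" arrives as Gradio tuple-style history: a list of (user, assistant) pairs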
+     for turn in messages:
+         if turn[0]:
+             chat_template.append({"role": "user", "content": [{"type": "text", "text": turn[0]}]})
+         if turn[1]:
+             chat_template.append({"role": "assistant", "content": [{"type": "text", "text": turn[1]}]})
+     chat_template.append({"role": "assistant", "content": [{"type": "text", "text": ""}]})
+
+     inputs = input_processor.apply_chat_template(
+         chat_template,
+         add_generation_prompt=True,
+         tokenize=True,
+         return_dict=True,
+         return_tensors="pt",
+     ).to(device=model_3n.device, dtype=torch.bfloat16)
+
      with torch.no_grad():
+         output_tokens = model_3n.generate(
+             **inputs,
+             max_new_tokens=max_new_tokens,
+             temperature=temperature,
+             top_p=top_p,
+             top_k=top_k,
+             repetition_penalty=repetition_penalty,
+             do_sample=True,
+         )
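+
+     # decode only the newly generated tokens (slice the prompt off the output)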
+     generated_text = input_processor.batch_decode(
+         output_tokens[:, inputs['input_ids'].shape[1]:], skip_special_tokens=True
+     )[0]
+     return generated_text.strip()
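+
+ # On ZeroGPU Spaces, @spaces.GPU allocates a GPU for the duration of each call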
+ @spaces.GPU()
+ def gradio_fn(message, history):
+     response = infer_text(history + [(message, None)])
+     return response
+
+ with gr.Blocks(css="""
+ .gradio-container {
+     max-width: 600px;
+     margin: auto;
+     padding: 20px;
+     font-family: sans-serif;
+     position: relative;
+ }
+ .chatbot {
+     height: 500px !important;
+     overflow-y: auto;
+ }
+ .corner {
+     position: fixed;
+     bottom: 2px;
+     z-index: 9999;
+     pointer-events: none;
+ }
+ #left { left: 2px; }
+ #right { right: 2px; }
+ .corner img {
+     height: 500px;  /* fixed height */
+     width: auto;    /* auto to keep aspect ratio */
+ }
+
+ """) as demo:
+     gr.Markdown(
+         """
+         <div style='text-align: center; padding: 10px;'>
+             <h1 style='font-size: 2.2em; margin-bottom: 0.2em;'>🤖 <span style='color: #4F46E5;'>kRISHNA.ai</span></h1>
+             <p style='font-size: 1.1em; color: #555;'>5000 Years of Ancient WISDOM with Modern AI ✨</p>
+         </div>
+         """,
+         elem_id="header"
      )
+     chat = gr.ChatInterface(
+         fn=gradio_fn,
          examples=[
+             "Hello!",
+             "How can I overcome fear of failure?",
+             "How do I forgive someone who hurt me deeply?",
+             "What can I do to stop overthinking?"
          ],
+         chatbot=gr.Chatbot(elem_classes="chatbot"),
+         theme="compact",
      )
+     gr.HTML(f"""
+     <div id="left" class="corner">
+         <img src="https://huggingface.co/spaces/p2kalita/kRISHNA.ai/resolve/main/assets/Arjun.png" alt="Arjun">
+     </div>
+     <div id="right" class="corner">
+         <img src="https://huggingface.co/spaces/p2kalita/kRISHNA.ai/resolve/main/assets/Krishna.png" alt="Krishna">
+     </div>
+     """)
+
  if __name__ == "__main__":
      demo.launch()
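
For a quick local sanity check of the new inference path, something along these lines should work, assuming the merged weights can be downloaded and there is enough memory; the history argument mirrors the (user, assistant) tuples that gradio_fn builds, and the file name is hypothetical:

    # smoke_test.py (hypothetical): calls infer_text() directly, bypassing Gradio
    from app import infer_text  # note: importing app loads the model

    history = [("Hello!", "Hello! How can I help?")]   # one completed turn
    question = "How can I overcome fear of failure?"   # next user message
    print(infer_text(history + [(question, None)], max_new_tokens=128))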