JDhruv14 committed on
Commit
9a2d448
·
verified ·
1 Parent(s): 52ab581

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -13
app.py CHANGED
@@ -3,6 +3,14 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
3
 
4
  MODEL_ID = os.getenv("MODEL_ID", "JDhruv14/merged_model")
5
 
 
 
 
 
 
 
 
 
6
  # Load once (CPU until first call; device_map will move to GPU on first run)
7
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
8
  model = AutoModelForCausalLM.from_pretrained(
@@ -20,11 +28,21 @@ def _msgs_from_history(history, system_text):
20
  msgs = []
21
  if system_text:
22
  msgs.append({"role": "system", "content": system_text})
23
- for user, assistant in history:
24
- if user:
25
- msgs.append({"role": "user", "content": user})
26
- if assistant:
27
- msgs.append({"role": "assistant", "content": assistant})
 
 
 
 
 
 
 
 
 
 
28
  return msgs
29
 
30
  def _eos_ids(tok):
@@ -41,7 +59,6 @@ def _eos_ids(tok):
41
  ids.add(im_end)
42
  except Exception:
43
  pass
44
- # Fallback: if still empty, just skip setting eos_token_id in GenerationConfig
45
  return list(ids)
46
 
47
  def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new):
@@ -68,18 +85,17 @@ def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new)
68
  with torch.no_grad():
69
  out = model.generate(**inputs, generation_config=gen_cfg)
70
 
71
- # slice off the prompt so we show only the assistant reply
72
  new_tokens = out[:, inputs["input_ids"].shape[1]:]
73
  reply = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
74
  return reply
75
 
76
  @spaces.GPU()
77
  def gradio_fn(message, history):
78
- # Minimal fix: call the defined chat_fn with sensible defaults
79
  return chat_fn(
80
  message=message,
81
  history=history,
82
- system_text="",
83
  temperature=0.7,
84
  top_p=0.95,
85
  max_new=512,
@@ -115,7 +131,7 @@ with gr.Blocks(css="""
115
  gr.Markdown(
116
  """
117
  <div style='text-align: center; padding: 10px;'>
118
- <h1 style='font-size: 2.2em; margin-bottom: 0.2em;'>🤖 <span style='color: #4F46E5;'>kRISHNA.ai</span></h1>
119
  <p style='font-size: 1.1em; color: #555;'>5000-Years of Ancient WISDOM with Modern AI ✨</p>
120
  </div>
121
  """,
@@ -129,8 +145,8 @@ with gr.Blocks(css="""
129
  "How do I forgive someone who hurt me deeply?",
130
  "What can I do to stop overthinking?"
131
  ],
132
- chatbot=gr.Chatbot(elem_classes="chatbot"),
133
- theme="compact",
134
  )
135
  gr.HTML(f"""
136
  <div id="left" class="corner">
@@ -141,6 +157,5 @@ with gr.Blocks(css="""
141
  </div>
142
  """)
143
 
144
-
145
  if __name__ == "__main__":
146
  demo.launch()
 
3
 
4
  MODEL_ID = os.getenv("MODEL_ID", "JDhruv14/merged_model")
5
 
6
+ # --- System prompt (Gita persona) ---
7
+ GITA_SYSTEM_PROMPT = """You are KRISHNA.ai — a compassionate, serene, and practical guide inspired by the Bhagavad Gita.
8
+ Style: calm, clear, inclusive, and down-to-earth. Use everyday language, avoid jargon.
9
+ When fitting, quote a brief shloka with Chapter:Verse (e.g., 2:47) and give a one-line meaning. Do not over-quote.
10
+ Emphasize: selfless action (karma-yoga), equanimity, disciplined mind, devotion, and wisdom — applicable to modern life.
11
+ Be non-sectarian and respectful of all beliefs. If a topic is clinical/medical/legal, gently suggest professional help.
12
+ Prefer concise replies (5–10 sentences). Use short steps/bullets for “how-to” answers. End with a one-line “Essence:” summary when helpful."""
13
+
14
  # Load once (CPU until first call; device_map will move to GPU on first run)
15
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
16
  model = AutoModelForCausalLM.from_pretrained(
 
28
  msgs = []
29
  if system_text:
30
  msgs.append({"role": "system", "content": system_text})
31
+ if not history:
32
+ return msgs
33
+
34
+ # Support both new "messages" format and legacy (user, assistant) tuples
35
+ if isinstance(history[0], dict) and "role" in history[0] and "content" in history[0]:
36
+ for m in history:
37
+ role, content = m.get("role"), m.get("content")
38
+ if role in ("user", "assistant", "system") and content:
39
+ msgs.append({"role": role, "content": content})
40
+ else:
41
+ for user, assistant in history:
42
+ if user:
43
+ msgs.append({"role": "user", "content": user})
44
+ if assistant:
45
+ msgs.append({"role": "assistant", "content": assistant})
46
  return msgs
47
 
48
  def _eos_ids(tok):
 
59
  ids.add(im_end)
60
  except Exception:
61
  pass
 
62
  return list(ids)
63
 
64
  def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new):
 
85
  with torch.no_grad():
86
  out = model.generate(**inputs, generation_config=gen_cfg)
87
 
 
88
  new_tokens = out[:, inputs["input_ids"].shape[1]:]
89
  reply = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
90
  return reply
91
 
92
  @spaces.GPU()
93
  def gradio_fn(message, history):
94
+ # Inject the Gita system prompt here
95
  return chat_fn(
96
  message=message,
97
  history=history,
98
+ system_text=GITA_SYSTEM_PROMPT,
99
  temperature=0.7,
100
  top_p=0.95,
101
  max_new=512,
 
131
  gr.Markdown(
132
  """
133
  <div style='text-align: center; padding: 10px;'>
134
+ <h1 style='font-size: 2.2em; margin-bottom: 0.2em;'><span style='color: #4F46E5;'>kRISHNA.ai</span></h1>
135
  <p style='font-size: 1.1em; color: #555;'>5000-Years of Ancient WISDOM with Modern AI ✨</p>
136
  </div>
137
  """,
 
145
  "How do I forgive someone who hurt me deeply?",
146
  "What can I do to stop overthinking?"
147
  ],
148
+ chatbot=gr.Chatbot(type="messages", elem_classes="chatbot"),
149
+ type="messages",
150
  )
151
  gr.HTML(f"""
152
  <div id="left" class="corner">
 
157
  </div>
158
  """)
159
 
 
160
  if __name__ == "__main__":
161
  demo.launch()