JDhruv14 commited on
Commit
c235810
·
verified ·
1 Parent(s): 497537b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -12
app.py CHANGED
@@ -63,6 +63,16 @@ def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new)
63
  if eos:
64
  gen_cfg_kwargs["eos_token_id"] = eos
65
 
 
 
 
 
 
 
 
 
 
 
66
  def infer_text(history, system_text=""):
67
  """
68
  Reply in the user’s language with 2–3 concise points (200–400 words); cite Gita verses when relevant.
@@ -85,16 +95,6 @@ def infer_text(history, system_text=""):
85
  min_new=128,
86
  )
87
 
88
- gen_cfg = GenerationConfig(**gen_cfg_kwargs)
89
-
90
- with torch.no_grad():
91
- out = model.generate(**inputs, generation_config=gen_cfg)
92
-
93
- # slice off the prompt so we show only the assistant reply
94
- new_tokens = out[:, inputs["input_ids"].shape[1]:]
95
- reply = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
96
- return reply
97
-
98
  @spaces.GPU()
99
  def gradio_fn(message, history):
100
  response = infer_text(history + [(message, None)])
@@ -155,7 +155,5 @@ with gr.Blocks(css="""
155
  </div>
156
  """)
157
 
158
-
159
  if __name__ == "__main__":
160
  demo.launch()
161
-
 
63
  if eos:
64
  gen_cfg_kwargs["eos_token_id"] = eos
65
 
66
+ gen_cfg = GenerationConfig(**gen_cfg_kwargs)
67
+
68
+ with torch.no_grad():
69
+ out = model.generate(**inputs, generation_config=gen_cfg)
70
+
71
+ # slice off the prompt so we show only the assistant reply
72
+ new_tokens = out[:, inputs["input_ids"].shape[1]:]
73
+ reply = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
74
+ return reply
75
+
76
  def infer_text(history, system_text=""):
77
  """
78
  Reply in the user’s language with 2–3 concise points (200–400 words); cite Gita verses when relevant.
 
95
  min_new=128,
96
  )
97
 
 
 
 
 
 
 
 
 
 
 
98
  @spaces.GPU()
99
  def gradio_fn(message, history):
100
  response = infer_text(history + [(message, None)])
 
155
  </div>
156
  """)
157
 
 
158
  if __name__ == "__main__":
159
  demo.launch()