krish10 committed on
Commit
4a4eadb
·
verified ·
1 Parent(s): 67e0693

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -15
app.py CHANGED
@@ -27,9 +27,9 @@ def stream_response(title_text, abstract_text):
27
 
28
  generated = inputs["input_ids"]
29
  past_key_values = None
30
- response_text = ""
31
 
32
- for _ in range(1024): # limit max length
 
33
  outputs = model(input_ids=generated, past_key_values=past_key_values, use_cache=True)
34
  next_token_logits = outputs.logits[:, -1, :]
35
  next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
@@ -41,28 +41,33 @@ def stream_response(title_text, abstract_text):
41
  decoded_output = tokenizer.decode(generated[0], skip_special_tokens=True)
42
 
43
  if "<|assistant|>" in decoded_output:
44
- response_text = decoded_output.split("<|assistant|>")[-1].strip()
45
- yield response_text
46
 
47
  with gr.Blocks() as demo:
48
- gr.Markdown("## 🧠 RCT Classifier Demonstration (Streaming Enabled)")
 
 
49
 
50
- chatbot = gr.Chatbot()
51
  with gr.Row():
52
  title = gr.Textbox(label="Title", placeholder="Enter title")
53
  abstract = gr.Textbox(label="Abstract", placeholder="Enter abstract", lines=6)
 
54
  submit = gr.Button("Classify")
55
 
56
  def stream_chat(title_text, abstract_text):
57
- user_message = f"Title: {title_text.strip()}\nAbstract: {abstract_text.strip()}"
58
- yield (user_message, "") # show user message
59
- response_stream = stream_response(title_text, abstract_text)
60
- collected = ""
61
- for partial in response_stream:
62
- collected = partial
63
- yield (user_message, collected)
64
-
65
- submit.stream(fn=stream_chat, inputs=[title, abstract], outputs=chatbot)
 
 
 
66
 
67
  if __name__ == "__main__":
68
  demo.launch()
 
27
 
28
  generated = inputs["input_ids"]
29
  past_key_values = None
 
30
 
31
+ reply = ""
32
+ for _ in range(1024):
33
  outputs = model(input_ids=generated, past_key_values=past_key_values, use_cache=True)
34
  next_token_logits = outputs.logits[:, -1, :]
35
  next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
 
41
  decoded_output = tokenizer.decode(generated[0], skip_special_tokens=True)
42
 
43
  if "<|assistant|>" in decoded_output:
44
+ reply = decoded_output.split("<|assistant|>")[-1].strip()
45
+ yield [(user_input, reply)]
46
 
47
  with gr.Blocks() as demo:
48
+ gr.Markdown("## 🧠 RCT Classifier Demonstration (Streaming, Stateless)")
49
+
50
+ chatbot = gr.Chatbot(label="Model Response", type="tuple")
51
 
 
52
  with gr.Row():
53
  title = gr.Textbox(label="Title", placeholder="Enter title")
54
  abstract = gr.Textbox(label="Abstract", placeholder="Enter abstract", lines=6)
55
+
56
  submit = gr.Button("Classify")
57
 
58
  def stream_chat(title_text, abstract_text):
59
+ user_input = f"Title: {title_text.strip()}\nAbstract: {abstract_text.strip()}"
60
+ yield [(user_input, "")] # show initial empty response
61
+
62
+ response_generator = stream_response(title_text, abstract_text)
63
+ for update in response_generator:
64
+ yield update
65
+
66
+ submit.click(
67
+ fn=stream_chat,
68
+ inputs=[title, abstract],
69
+ outputs=chatbot
70
+ )
71
 
72
  if __name__ == "__main__":
73
  demo.launch()