RamishRasool14 committed
Commit
1da696e
1 Parent(s): 2ca9a16
Files changed (1)
  1. app.py +74 -27
app.py CHANGED
@@ -1,7 +1,7 @@
 import gradio as gr
 from unsloth import FastLanguageModel
-from transformers import TextStreamer
-import torch
+from transformers import TextIteratorStreamer
+from threading import Thread

 # Load your fine-tuned model and tokenizer
 model, tokenizer = FastLanguageModel.from_pretrained(
@@ -10,29 +10,62 @@ model, tokenizer = FastLanguageModel.from_pretrained(
     dtype='bf16',
     load_in_4bit=False,
 )
-text_streamer = TextStreamer(tokenizer)
 FastLanguageModel.for_inference(model)  # Enable optimized inference

-def predict(question, history):
-    history = history or []
-    history.append({"from": "human", "value": question})
+def get_streaming_generator(model, tokenizer, history, max_new_tokens=8192):
+    """Function that returns a generator yielding streaming outputs"""
+    # Convert history to the format expected by tokenizer
+    formatted_history = []
+    for exchange in history:
+        formatted_history.append({"role": "user", "content": exchange[0]})
+        if len(exchange) > 1 and exchange[1]:
+            formatted_history.append({"role": "assistant", "content": exchange[1]})

     inputs = tokenizer(
-        tokenizer.apply_chat_template(
-            history,
-            tokenize=False,
-            add_generation_prompt=True
-        ),
-        return_tensors="pt"
-    ).to("cuda" if torch.cuda.is_available() else "cpu")
+        [
+            tokenizer.apply_chat_template(formatted_history,
+                                          tokenize=False,
+                                          add_generation_prompt=True),
+        ],
+        return_tensors="pt",
+        padding=True,
+        return_attention_mask=True
+    ).to("cuda")
+
+    # Create the streamer
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
+
+    # Run generation in a separate thread
+    generation_kwargs = dict(
+        input_ids=inputs["input_ids"],
+        attention_mask=inputs["attention_mask"],
+        streamer=streamer,
+        max_new_tokens=max_new_tokens
+    )
+
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+
+    return streamer

-    token = model.generate(**inputs, max_new_tokens=8192, streamer = text_streamer)
-    output = tokenizer.decode(token, skip_special_tokens=True)
-    output = output.split("[/INST]")[-1].strip()
-    yield output
+def predict(message, history):
+    # Add user message to history in the format Gradio expects
+    history = history or []
+    history.append([message, ""])
+
+    # Get the streamer with properly formatted history
+    streamer = get_streaming_generator(model, tokenizer, history)

-    history.append({"from": "gpt", "value": output})
-    return history
+    # Stream the response
+    full_response = ""
+    for text_chunk in streamer:
+        full_response += text_chunk
+        # Update the last message with the current full response
+        history[-1][1] = full_response
+        yield history
+
+def clear_chat():
+    return [], ""

 # Create the Gradio interface with Markdown support
 with gr.Blocks(css=".message { white-space: pre-wrap; }") as iface:
@@ -41,30 +74,44 @@ with gr.Blocks(css=".message { white-space: pre-wrap; }") as iface:
         container=True,
         height=600,
         bubble_full_width=False,
-        render_markdown=True,  # Enable markdown rendering
-        latex_delimiters=[  # Optional: Enable LaTeX rendering
+        render_markdown=True,
+        latex_delimiters=[
            {"left": "$$", "right": "$$", "display": True},
            {"left": "$", "right": "$", "display": False},
         ],
     )
-    submit = gr.Button("Submit")
     msg = gr.Textbox(
         label="Message",
         placeholder="Type your message here... (Markdown supported)",
         lines=2
     )
+    submit = gr.Button("Submit")
     clear = gr.Button("Clear")

-    state = gr.State([])
+    # Set up the chat interface with streaming
+    msg.submit(
+        predict,
+        [msg, chatbot],
+        [chatbot],
+        api_name="predict"
+    ).then(
+        lambda: "", None, [msg]  # Clear input after submission
+    )

     submit.click(
         predict,
-        [msg, state],
-        [chatbot, state],
-        api_name="predict"
+        [msg, chatbot],
+        [chatbot]
+    ).then(
+        lambda: "", None, [msg]  # Clear input after submission
     )

-    clear.click(lambda: None, None, chatbot, queue=False)
+    clear.click(
+        clear_chat,
+        None,
+        [chatbot, msg],
+        queue=False
+    )

 if __name__ == "__main__":
     iface.launch()
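
The substantive change in this commit is the streaming pattern: model.generate() blocks until every token has been produced, so the new code hands it a TextIteratorStreamer and runs it on a worker thread while the caller iterates the streamer for decoded text as it arrives. Below is a minimal, self-contained sketch of that pattern using plain transformers; the "gpt2" checkpoint and the prompt are placeholders for illustration, not part of this app.

# Minimal sketch of the TextIteratorStreamer pattern, independent of Unsloth.
# The "gpt2" checkpoint and the prompt below are illustrative placeholders.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Once upon a time", return_tensors="pt")

# skip_prompt=True keeps the echoed prompt out of the streamed text;
# skip_special_tokens is forwarded to tokenizer.decode().
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until it finishes, so it runs on a worker thread
# while the main thread consumes decoded chunks from the streamer.
thread = Thread(
    target=model.generate,
    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=32),
)
thread.start()

for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()

Because the prompt is skipped at the streamer level, the rewritten predict() no longer needs the old split("[/INST]") workaround for stripping the echoed instruction.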
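
On the Gradio side, predict() is now a generator: each yield returns the whole updated history, and Gradio re-renders the Chatbot after every chunk, which produces the typing effect. Here is a sketch of just that wiring, with a fake word loop standing in for the model stream; predict_stub and its canned text are illustrative only.

# Sketch of how Gradio consumes a generator event handler: every yield
# pushes an updated history to the Chatbot component.
import time

import gradio as gr

def predict_stub(message, history):
    history = history or []
    history.append([message, ""])  # same [user, bot] pair format as the commit uses
    for word in "this reply arrives one word at a time".split():
        history[-1][1] += word + " "
        time.sleep(0.1)  # stand-in for waiting on the model
        yield history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    msg.submit(predict_stub, [msg, chatbot], [chatbot]).then(lambda: "", None, [msg])

if __name__ == "__main__":
    demo.launch()

The .then() chain mirrors the commit's design: the streaming handler runs to completion first, then a second trivial handler clears the textbox.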