akhaliq HF Staff commited on
Commit
eae8d97
·
verified ·
1 Parent(s): af9ed97

Update Gradio app with multiple files

Browse files
Files changed (1) hide show
  1. models.py +51 -33
models.py CHANGED
@@ -69,12 +69,18 @@ def stream_generate_response(prompt: str, history: list) -> Generator[str, None,
69
  for human, bot in history:
70
  # Add past exchanges
71
  if human:
72
- messages.append({"role": "user", "content": human})
 
 
73
  if bot:
74
- messages.append({"role": "assistant", "content": bot})
 
 
75
 
76
  # Add the current prompt
77
- messages.append({"role": "user", "content": prompt})
 
 
78
 
79
  # Apply chat template
80
  text = tokenizer.apply_chat_template(
@@ -86,45 +92,57 @@ def stream_generate_response(prompt: str, history: list) -> Generator[str, None,
86
  # Prepare inputs and move to model device
87
  model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
88
 
89
- # Use TextStreamer for efficient token streaming
90
- streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
91
-
92
- # Start generation in a separate thread (TextStreamer uses an internal blocking mechanism)
93
- # Since Gradio's generator interface expects synchronous yields from the main thread
94
- # within the @spaces.GPU context, we need to adapt the TextStreamer output.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
- # A cleaner approach for Gradio streaming is direct model generation without TextStreamer:
 
97
 
 
98
  input_ids = model_inputs.input_ids
99
 
 
100
  generated_ids = model.generate(
101
  input_ids=input_ids,
102
  max_new_tokens=MAX_NEW_TOKENS,
103
  do_sample=DO_SAMPLE,
104
  temperature=TEMPERATURE,
105
  pad_token_id=tokenizer.eos_token_id,
106
- return_dict_in_generate=True,
107
- output_scores=True,
108
- min_new_tokens=1,
109
- # Enable iterative decoding
110
  repetition_penalty=1.1,
111
  )
112
-
113
- full_response = ""
114
- # Process output sequence token by token
115
- for seq in generated_ids.sequences:
116
- # Get the new tokens generated after the prompt
117
- new_tokens = seq[input_ids.shape[-1]:]
118
-
119
- # Decode only the newly generated part of the sequence so far
120
- current_response = tokenizer.decode(new_tokens, skip_special_tokens=True)
121
-
122
- # Yield only the difference from the previous chunk
123
- if len(current_response) > len(full_response):
124
- new_text = current_response[len(full_response):]
125
- full_response = current_response
126
- yield new_text
127
-
128
- # Final cleanup (sometimes the model output is slightly messy)
129
- if full_response:
130
- yield full_response.strip()
 
69
  for human, bot in history:
70
  # Add past exchanges
71
  if human:
72
+ messages.append({
73
+ "role": "user", "content": human
74
+ })
75
  if bot:
76
+ messages.append({
77
+ "role": "assistant", "content": bot
78
+ })
79
 
80
  # Add the current prompt
81
+ messages.append({
82
+ "role": "user", "content": prompt
83
+ })
84
 
85
  # Apply chat template
86
  text = tokenizer.apply_chat_template(
 
92
  # Prepare inputs and move to model device
93
  model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
94
 
95
+ # Create a custom streamer that works with Gradio
96
+ class GradioStreamer:
97
+ def __init__(self, tokenizer):
98
+ self.tokenizer = tokenizer
99
+ self.text_queue = []
100
+ self.generated_text = ""
101
+
102
+ def put(self, value):
103
+ # Decode the new tokens and add to queue
104
+ if isinstance(value, torch.Tensor):
105
+ new_text = self.tokenizer.decode(value, skip_special_tokens=True)
106
+ # Only yield the new part
107
+ if new_text.startswith(self.generated_text):
108
+ new_part = new_text[len(self.generated_text):]
109
+ if new_part:
110
+ self.text_queue.append(new_part)
111
+ self.generated_text = new_text
112
+ else:
113
+ # Sometimes the decoding might not align perfectly
114
+ self.text_queue.append(new_text)
115
+ self.generated_text = new_text
116
+
117
+ def end(self):
118
+ pass
119
+
120
+ def __iter__(self):
121
+ return iter(self.text_queue)
122
 
123
+ # Create our custom streamer
124
+ gradio_streamer = GradioStreamer(tokenizer)
125
 
126
+ # Generate with streaming
127
  input_ids = model_inputs.input_ids
128
 
129
+ # Generate tokens one by one for true streaming
130
  generated_ids = model.generate(
131
  input_ids=input_ids,
132
  max_new_tokens=MAX_NEW_TOKENS,
133
  do_sample=DO_SAMPLE,
134
  temperature=TEMPERATURE,
135
  pad_token_id=tokenizer.eos_token_id,
136
+ streamer=gradio_streamer,
 
 
 
137
  repetition_penalty=1.1,
138
  )
139
+
140
+ # Yield the text as it's generated
141
+ accumulated_text = ""
142
+ for new_chunk in gradio_streamer.text_queue:
143
+ accumulated_text += new_chunk
144
+ yield accumulated_text
145
+
146
+ # Final yield to ensure complete text is sent
147
+ if accumulated_text:
148
+ yield accumulated_text.strip()