Update Gradio app with multiple files
Browse files
models.py
CHANGED
|
@@ -69,12 +69,18 @@ def stream_generate_response(prompt: str, history: list) -> Generator[str, None,
|
|
| 69 |
for human, bot in history:
|
| 70 |
# Add past exchanges
|
| 71 |
if human:
|
| 72 |
-
messages.append({
|
|
|
|
|
|
|
| 73 |
if bot:
|
| 74 |
-
messages.append({
|
|
|
|
|
|
|
| 75 |
|
| 76 |
# Add the current prompt
|
| 77 |
-
messages.append({
|
|
|
|
|
|
|
| 78 |
|
| 79 |
# Apply chat template
|
| 80 |
text = tokenizer.apply_chat_template(
|
|
@@ -86,45 +92,57 @@ def stream_generate_response(prompt: str, history: list) -> Generator[str, None,
|
|
| 86 |
# Prepare inputs and move to model device
|
| 87 |
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
| 88 |
|
| 89 |
-
#
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
-
#
|
|
|
|
| 97 |
|
|
|
|
| 98 |
input_ids = model_inputs.input_ids
|
| 99 |
|
|
|
|
| 100 |
generated_ids = model.generate(
|
| 101 |
input_ids=input_ids,
|
| 102 |
max_new_tokens=MAX_NEW_TOKENS,
|
| 103 |
do_sample=DO_SAMPLE,
|
| 104 |
temperature=TEMPERATURE,
|
| 105 |
pad_token_id=tokenizer.eos_token_id,
|
| 106 |
-
|
| 107 |
-
output_scores=True,
|
| 108 |
-
min_new_tokens=1,
|
| 109 |
-
# Enable iterative decoding
|
| 110 |
repetition_penalty=1.1,
|
| 111 |
)
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
for
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
# Yield only the difference from the previous chunk
|
| 123 |
-
if len(current_response) > len(full_response):
|
| 124 |
-
new_text = current_response[len(full_response):]
|
| 125 |
-
full_response = current_response
|
| 126 |
-
yield new_text
|
| 127 |
-
|
| 128 |
-
# Final cleanup (sometimes the model output is slightly messy)
|
| 129 |
-
if full_response:
|
| 130 |
-
yield full_response.strip()
|
|
|
|
| 69 |
for human, bot in history:
|
| 70 |
# Add past exchanges
|
| 71 |
if human:
|
| 72 |
+
messages.append({
|
| 73 |
+
"role": "user", "content": human
|
| 74 |
+
})
|
| 75 |
if bot:
|
| 76 |
+
messages.append({
|
| 77 |
+
"role": "assistant", "content": bot
|
| 78 |
+
})
|
| 79 |
|
| 80 |
# Add the current prompt
|
| 81 |
+
messages.append({
|
| 82 |
+
"role": "user", "content": prompt
|
| 83 |
+
})
|
| 84 |
|
| 85 |
# Apply chat template
|
| 86 |
text = tokenizer.apply_chat_template(
|
|
|
|
| 92 |
# Prepare inputs and move to model device
|
| 93 |
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
| 94 |
|
| 95 |
+
# Create a custom streamer that works with Gradio
|
| 96 |
+
class GradioStreamer:
    """Collect decoded text chunks pushed by ``model.generate(streamer=...)``.

    The streamer keeps every generated token id, re-decodes the whole
    sequence on each ``put`` and queues only the newly appeared suffix.
    Decoding cumulatively (instead of decoding each pushed tensor alone)
    keeps repeated tokens, inter-token spacing and multi-byte characters
    intact.

    Note: this class only buffers chunks in a plain list. It does not block
    or hand chunks across threads, so a consumer that iterates it after a
    synchronous ``generate()`` call sees all chunks at once.

    Args:
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=...)``.
        skip_prompt: When true (default), the first tensor passed to
            ``put`` is treated as the prompt and is not emitted.

    Attributes:
        text_queue: List of text chunks in generation order.
        generated_text: All text decoded so far.
    """

    def __init__(self, tokenizer, skip_prompt=True):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.text_queue = []
        self.generated_text = ""
        self._token_ids = []
        self._prompt_pending = skip_prompt

    def put(self, value):
        """Receive new token ids and queue the text they add.

        Non-tensor values are ignored, as in the original implementation.

        Raises:
            ValueError: If ``value`` holds more than one sequence.
        """
        if not isinstance(value, torch.Tensor):
            return

        # generate() pushes the prompt as a (1, seq_len) tensor and new
        # tokens as a (1,) tensor; normalise both to a flat id sequence.
        if value.dim() > 1:
            if value.shape[0] > 1:
                raise ValueError("GradioStreamer only supports batch size 1")
            value = value[0]

        if self._prompt_pending:
            # The first push is the prompt itself; do not echo it back.
            self._prompt_pending = False
            return

        self._token_ids.extend(value.reshape(-1).tolist())
        self._flush()

    def _flush(self, final=False):
        """Decode all cached ids and queue the not-yet-emitted suffix."""
        if not self._token_ids:
            return
        text = self.tokenizer.decode(self._token_ids, skip_special_tokens=True)
        # A trailing replacement character means the last token ended in the
        # middle of a multi-byte character; wait for the next token.
        if not final and text.endswith("\ufffd"):
            return
        new_part = text[len(self.generated_text):]
        if new_part:
            self.text_queue.append(new_part)
        self.generated_text = text

    def end(self):
        """Flush any text that was held back while waiting for more tokens."""
        self._flush(final=True)

    def __iter__(self):
        return iter(self.text_queue)
|
| 122 |
|
| 123 |
+
# Create our custom streamer
|
| 124 |
+
gradio_streamer = GradioStreamer(tokenizer)
|
| 125 |
|
| 126 |
+
# Generate with streaming
|
| 127 |
input_ids = model_inputs.input_ids
|
| 128 |
|
| 129 |
+
# Run generation in a single blocking call; new tokens are handed to the streamer
|
| 130 |
generated_ids = model.generate(
|
| 131 |
input_ids=input_ids,
|
| 132 |
max_new_tokens=MAX_NEW_TOKENS,
|
| 133 |
do_sample=DO_SAMPLE,
|
| 134 |
temperature=TEMPERATURE,
|
| 135 |
pad_token_id=tokenizer.eos_token_id,
|
| 136 |
+
streamer=gradio_streamer,
|
|
|
|
|
|
|
|
|
|
| 137 |
repetition_penalty=1.1,
|
| 138 |
)
|
| 139 |
+
|
| 140 |
+
# Replay the queued chunks, yielding the growing text (generate() has already returned here)
|
| 141 |
+
accumulated_text = ""
|
| 142 |
+
for new_chunk in gradio_streamer.text_queue:
|
| 143 |
+
accumulated_text += new_chunk
|
| 144 |
+
yield accumulated_text
|
| 145 |
+
|
| 146 |
+
# Final yield to ensure complete text is sent
|
| 147 |
+
if accumulated_text:
|
| 148 |
+
yield accumulated_text.strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|