Spaces: Running on Zero
Update app.py

app.py CHANGED
@@ -5,6 +5,8 @@ import threading
 import queue
 import time
 import spaces
+import sys
+from io import StringIO
 
 # Model configuration
 model_name = "HelpingAI/Dhanishtha-2.0-preview"

@@ -30,30 +32,30 @@ def load_model():
 
     print("Model loaded successfully!")
 
-class ...(TextStreamer):
-    """..."""
-    def __init__(self, tokenizer, skip_prompt):
-        # TextStreamer only accepts tokenizer and skip_prompt parameters
-        super().__init__(tokenizer, skip_prompt)
+class StreamCapture:
+    """Capture streaming output from TextStreamer"""
+    def __init__(self):
         self.text_queue = queue.Queue()
-        self.generated_text = ""
-        self.skip_special_tokens = True  # Handle this manually if needed
-
-    def on_finalized_text(self, text: str, stream_end: bool = False):
-        """Called when text is finalized"""
-        self.generated_text += text
-        self.text_queue.put(text)
-        if stream_end:
-            self.text_queue.put(None)
-
-    def get_generated_text(self):
-        """Get all generated text so far"""
-        return self.generated_text
+        self.captured_text = ""
+
+    def write(self, text):
+        """Capture written text"""
+        if text and text.strip():
+            self.captured_text += text
+            self.text_queue.put(text)
+        return len(text)
+
+    def flush(self):
+        """Flush method for compatibility"""
+        pass
+
+    def get_text(self):
+        """Get all captured text"""
+        return self.captured_text
 
     def reset(self):
-        """Reset the ..."""
-        self.generated_text = ""
-        # Clear the queue
+        """Reset the capture"""
+        self.captured_text = ""
         while not self.text_queue.empty():
             try:
                 self.text_queue.get_nowait()

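Why write() and flush() are the only methods StreamCapture needs: print(), which transformers' TextStreamer uses for output, only requires that whatever is bound to sys.stdout offers those two methods. A minimal self-contained sketch of that duck-typed contract (class and variable names here are illustrative, not from app.py):

import sys

class CaptureSketch:
    """Illustrative stand-in for sys.stdout: collects printed text."""
    def __init__(self):
        self.chunks = []

    def write(self, text):
        self.chunks.append(text)
        return len(text)  # file-like write() reports characters written

    def flush(self):
        pass  # print(..., flush=True) calls this; nothing is buffered here

capture = CaptureSketch()
original = sys.stdout
try:
    sys.stdout = capture
    print("Hello", end="")  # goes to capture, not the terminal
finally:
    sys.stdout = original   # always restore the real stdout

assert "".join(capture.chunks) == "Hello"
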
@@ -89,11 +91,16 @@ def generate_response(message, history, max_tokens, temperature, top_p):
     # Tokenize input
     model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
-    # Create ...
-    ...
-    streamer.reset()
+    # Create stream capture
+    stream_capture = StreamCapture()
 
-    # ...
+    # Create TextStreamer with our capture
+    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+    # Temporarily redirect the streamer's output
+    original_stdout = sys.stdout
+
+    # Generation parameters
     generation_kwargs = {
         **model_inputs,
         "max_new_tokens": max_tokens,

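For context on the approach this hunk settles on: a plain TextStreamer simply prints decoded text to stdout as tokens are generated, so capturing stdout is enough to observe the stream without subclassing. A sketch of that default behavior, assuming a small public model (sshleifer/tiny-gpt2, chosen only for illustration):

# TextStreamer prints each decoded chunk to stdout during generate().
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
streamer = TextStreamer(tok, skip_prompt=True, skip_special_tokens=True)

inputs = tok(["Hello"], return_tensors="pt")
model.generate(**inputs, max_new_tokens=8, streamer=streamer)  # text appears on stdout
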
@@ -102,17 +109,21 @@ def generate_response(message, history, max_tokens, temperature, top_p):
         "do_sample": True,
         "pad_token_id": tokenizer.eos_token_id,
         "streamer": streamer,
-        "return_dict_in_generate": True
     }
 
-    # ...
+    # Start generation in a separate thread
     def generate():
         try:
+            # Redirect stdout to capture streamer output
+            sys.stdout = stream_capture
            with torch.no_grad():
                 model.generate(**generation_kwargs)
         except Exception as e:
-            ...
-            ...
+            stream_capture.text_queue.put(f"Error: {str(e)}")
+        finally:
+            # Restore stdout
+            sys.stdout = original_stdout
+            stream_capture.text_queue.put(None)  # Signal end
 
     thread = threading.Thread(target=generate)
     thread.start()

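Note the trade-off in this hunk: sys.stdout is process-global, so while generate() runs in the worker thread, prints from any thread land in stream_capture; the finally block is what guarantees the real stdout comes back even if generation raises. For comparison, a scoped sketch using contextlib.redirect_stdout (an alternative pattern, not what this commit does):

import contextlib
import io

buf = io.StringIO()
with contextlib.redirect_stdout(buf):  # swaps sys.stdout for this block only
    print("streamed text", end="")     # captured in buf
print("captured:", buf.getvalue())     # the real stdout is restored here
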
@@ -121,7 +132,7 @@ def generate_response(message, history, max_tokens, temperature, top_p):
     generated_text = ""
     while True:
         try:
-            new_text = ...
+            new_text = stream_capture.text_queue.get(timeout=30)
             if new_text is None:
                 break
             generated_text += new_text

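On the consumer side, this loop pairs with the worker thread through text_queue: get(timeout=30) raises queue.Empty if the producer stalls, and the None sentinel pushed in the finally block ends the loop cleanly. A self-contained sketch of that handshake (chunk contents and timings are illustrative):

import queue
import threading
import time

q = queue.Queue()

def producer():
    for chunk in ("Hel", "lo", "!"):
        time.sleep(0.1)   # stand-in for token generation latency
        q.put(chunk)
    q.put(None)           # sentinel: no more text is coming

threading.Thread(target=producer).start()

text = ""
while True:
    try:
        chunk = q.get(timeout=30)  # raises queue.Empty if the producer stalls
    except queue.Empty:
        break
    if chunk is None:
        break
    text += chunk

print(text)  # -> Hello!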