Commit 50e2fd2 · Parent(s): 14d86a4

Changed Generate stream to async

Files changed: main/api.py (+7 -4)
```diff
--- a/main/api.py
+++ b/main/api.py
@@ -2,7 +2,8 @@ import os
 from pathlib import Path
 from threading import Thread
 import torch
-from typing import Optional,
+from typing import Optional, List, AsyncIterator
+import asyncio
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
 from .utils.logging import setup_logger
 
@@ -248,12 +249,12 @@ class LLMApi:
             self.logger.error(f"Error generating response: {str(e)}")
             raise
 
-    def generate_stream(
+    async def generate_stream(
         self,
         prompt: str,
         system_message: Optional[str] = None,
         max_new_tokens: Optional[int] = None
-    ) ->
+    ) -> AsyncIterator[str]:
        """
        Generate a streaming response for the given prompt.
        """
@@ -287,10 +288,12 @@ class LLMApi:
             thread = Thread(target=self.generation_model.generate, kwargs=generation_kwargs)
             thread.start()
 
-            #
+            # Use async generator to yield chunks
             for new_text in streamer:
                 self.logger.debug(f"Generated chunk: {new_text[:50]}...")
                 yield new_text
+                # Add a small delay to allow other tasks to run
+                await asyncio.sleep(0)
 
         except Exception as e:
             self.logger.error(f"Error in streaming generation: {str(e)}")
```