from langfuse import Langfuse
from langfuse.decorators import observe, langfuse_context
from config.config import settings
import os

# Initialize Langfuse (note: credentials are hardcoded here; in production they
# should come from the environment or a secrets store)
os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-9f2c32d2-266f-421d-9b87-51377f0a268c"
os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-229e10c5-6210-4a4b-a432-0f17bc66e56c"
os.environ["LANGFUSE_HOST"] = "https://chris4k-langfuse-template-space.hf.space"  # 🇪🇺 EU region

try:
    langfuse = Langfuse()
except Exception:
    print("Langfuse Offline")
# model_manager.py
import logging
from typing import Any, Dict, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from llama_cpp import Llama

from config.config import GenerationConfig, ModelConfig


class ModelManager:
    """Loads, stores, and unloads models and tokenizers keyed by model_id."""

    def __init__(self, device: Optional[str] = None):
        self.logger = logging.getLogger(__name__)
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.models: Dict[str, Any] = {}
        self.tokenizers: Dict[str, Any] = {}
    def load_model(self, model_id: str, model_path: str, model_type: str, config: ModelConfig) -> None:
        """Load a model with the specified configuration."""
        try:
            # Different model families (textgen, llama, gguf, text2video,
            # text2image, ...) could be handled here with a factory pattern.
            if model_type == "llama":
                self.tokenizers[model_id] = AutoTokenizer.from_pretrained(
                    model_path,
                    padding_side='left',
                    trust_remote_code=True,
                    **config.tokenizer_kwargs
                )
                if self.tokenizers[model_id].pad_token is None:
                    self.tokenizers[model_id].pad_token = self.tokenizers[model_id].eos_token
                self.models[model_id] = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    device_map="auto",
                    trust_remote_code=True,
                    **config.model_kwargs
                )
            elif model_type == "gguf":
                # TODO load the model from the cache first; if not found, load
                # it and save it in the cache, e.g.:
                # from huggingface_hub import hf_hub_download
                # prm_model_path = hf_hub_download(
                #     repo_id="tensorblock/Llama3.1-8B-PRM-Mistral-Data-GGUF",
                #     filename="Llama3.1-8B-PRM-Mistral-Data-Q4_K_M.gguf"
                # )
                self.models[model_id] = self._load_quantized_model(
                    model_path,
                    **config.quantization_kwargs
                )
            else:
                raise ValueError(f"Unknown model type: {model_type}")
        except Exception as e:
            self.logger.error(f"Failed to load model {model_id}: {str(e)}")
            raise
    def unload_model(self, model_id: str) -> None:
        """Unload a model and free resources."""
        if model_id in self.models:
            del self.models[model_id]
        if model_id in self.tokenizers:
            del self.tokenizers[model_id]
        torch.cuda.empty_cache()
    def _load_quantized_model(self, model_path: str, **kwargs) -> Llama:
        """Load a quantized GGUF model via llama-cpp-python."""
        try:
            n_gpu_layers = -1 if torch.cuda.is_available() else 0
            model = Llama(
                model_path=model_path,
                n_ctx=kwargs.get('n_ctx', 2048),
                n_batch=kwargs.get('n_batch', 512),
                n_gpu_layers=kwargs.get('n_gpu_layers', n_gpu_layers),
                verbose=kwargs.get('verbose', False)
            )
            return model
        except Exception as e:
            self.logger.error(f"Failed to load GGUF model: {str(e)}")
            raise
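# Usage sketch for ModelManager (assumes ModelConfig() is constructible with
# default kwargs dicts; defined only, never called here):
def _example_model_roundtrip() -> None:
    manager = ModelManager()
    manager.load_model("demo", "meta-llama/Llama-3.2-1B-Instruct", "llama", ModelConfig())
    tokenizer = manager.tokenizers["demo"]   # HF tokenizer handle
    model = manager.models["demo"]           # HF model handle
    _ = tokenizer, model
    manager.unload_model("demo")             # frees GPU memory when done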
# cache.py
from collections import OrderedDict
from typing import Optional, Tuple

from config.config import GenerationConfig


class ResponseCache:
    """LRU cache mapping (prompt, config) pairs to (response, score) tuples.

    functools.lru_cache cannot store externally computed values, so this uses
    an OrderedDict-based LRU instead. See the usage sketch below the class.
    """

    def __init__(self, cache_size: int = 1000):
        self.cache_size = cache_size
        self._cache: OrderedDict = OrderedDict()

    def _make_key(self, prompt: str, config: GenerationConfig) -> Tuple[str, str]:
        return (prompt, str(hash(str(config.__dict__))))

    def cache_response(self, prompt: str, config: GenerationConfig, response: str, score: float) -> None:
        key = self._make_key(prompt, config)
        self._cache[key] = (response, score)
        self._cache.move_to_end(key)
        if len(self._cache) > self.cache_size:
            self._cache.popitem(last=False)  # evict the least-recently-used entry

    def get_response(self, prompt: str, config: GenerationConfig) -> Optional[Tuple[str, float]]:
        key = self._make_key(prompt, config)
        if key in self._cache:
            self._cache.move_to_end(key)  # mark as recently used
            return self._cache[key]
        return None
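# Usage sketch for ResponseCache (assumes GenerationConfig() is constructible
# with defaults; defined only, never called here):
def _example_cache_usage() -> None:
    cache = ResponseCache(cache_size=2)
    config = GenerationConfig()
    cache.cache_response("Hi", config, "Hello!", 0.9)
    assert cache.get_response("Hi", config) == ("Hello!", 0.9)   # hit
    assert cache.get_response("Bye", config) is None             # miss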
# batch_processor.py
import asyncio
from typing import Any, Dict, List


class BatchProcessor:
    """Collects requests into batches of up to max_batch_size, flushing a
    partial batch after max_wait_time seconds. See the usage sketch below."""

    def __init__(self, max_batch_size: int = 32, max_wait_time: float = 0.1):
        self.max_batch_size = max_batch_size
        self.max_wait_time = max_wait_time
        self.pending_requests: List[Dict] = []
        self.lock = asyncio.Lock()

    async def add_request(self, request: Dict) -> Any:
        async with self.lock:
            self.pending_requests.append(request)
            if len(self.pending_requests) >= self.max_batch_size:
                return await self._process_batch()
        # Release the lock while waiting so other callers can enqueue requests,
        # then flush whatever is pending.
        await asyncio.sleep(self.max_wait_time)
        async with self.lock:
            if self.pending_requests:
                return await self._process_batch()

    async def _process_batch(self) -> List[Any]:
        batch = self.pending_requests[:self.max_batch_size]
        self.pending_requests = self.pending_requests[self.max_batch_size:]
        # TODO implement the batch processing logic
        return batch
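# Usage sketch: concurrent callers share one BatchProcessor; the caller whose
# request fills the batch receives the whole batch back, the other gets None
# (defined only, never called here):
async def _example_batch_usage() -> None:
    processor = BatchProcessor(max_batch_size=2, max_wait_time=0.05)
    results = await asyncio.gather(
        processor.add_request({"prompt": "a"}),
        processor.add_request({"prompt": "b"}),
    )
    print(results)  # e.g. [None, [{'prompt': 'a'}, {'prompt': 'b'}]]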
# base_generator.py
from abc import ABC, abstractmethod
from logging import getLogger
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple

from config.config import GenerationConfig, ModelConfig
# Sibling modules from this repo (assuming a flat module layout)
from model_manager import ModelManager
from cache import ResponseCache
from batch_processor import BatchProcessor
from health_check import HealthCheck


class BaseGenerator(ABC):
    """Base class for all generator implementations."""

    def __init__(
        self,
        model_name: str,
        device: Optional[str] = None,
        default_generation_config: Optional[GenerationConfig] = None,
        model_config: Optional[ModelConfig] = None,
        cache_size: int = 1000,
        max_batch_size: int = 32
    ):
        self.logger = getLogger(__name__)
        self.model_manager = ModelManager(device)
        self.cache = ResponseCache(cache_size)
        self.batch_processor = BatchProcessor(max_batch_size)
        self.health_check = HealthCheck()
        self.default_config = default_generation_config or GenerationConfig()
        self.model_config = model_config or ModelConfig()

    async def generate_stream(
        self,
        prompt: str,
        config: Optional[GenerationConfig] = None
    ) -> AsyncGenerator[str, None]:
        """Yield generated text incrementally; subclasses should override."""
        raise NotImplementedError
        yield  # unreachable; marks this as an async generator

    @abstractmethod
    def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
        ...

    @abstractmethod
    def generate(self, prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        ...
# strategy.py
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple


class GenerationStrategy(ABC):
    """Base class for generation strategies."""

    @abstractmethod
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        ...


class DefaultStrategy(GenerationStrategy):
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        output = generator.model.generate(input_ids, **model_kwargs)
        return generator.tokenizer.decode(output[0], skip_special_tokens=True)


class MajorityVotingStrategy(GenerationStrategy):
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        outputs = []
        for _ in range(num_samples):
            input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
            output = generator.model.generate(input_ids, **model_kwargs)
            outputs.append(generator.tokenizer.decode(output[0], skip_special_tokens=True))
        return max(set(outputs), key=outputs.count)


class BestOfN(GenerationStrategy):
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        # Note: scoring assumes the PRM exposes an HF-style forward pass that
        # returns logits; a llama.cpp model would need a different scoring call.
        scored_outputs = []
        for _ in range(num_samples):
            input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
            output = generator.model.generate(input_ids, **model_kwargs)
            response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
            score = generator.prm_model(**generator.tokenizer(response, return_tensors="pt").to(generator.device)).logits.mean().item()
            scored_outputs.append((response, score))
        return max(scored_outputs, key=lambda x: x[1])[0]


class BeamSearch(GenerationStrategy):
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        outputs = generator.model.generate(
            input_ids,
            num_beams=num_samples,
            num_return_sequences=num_samples,
            **model_kwargs
        )
        # Beams come back ordered by score, so the first sequence is the best.
        return generator.tokenizer.decode(outputs[0], skip_special_tokens=True)


class DVT(GenerationStrategy):
    """Diverse verifier tree search: sample `breadth` candidates, then extend
    the best ones `depth - 1` times, scoring each with the PRM."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], breadth: int = 5, depth: int = 2, **kwargs) -> str:
        results = []
        for _ in range(breadth):
            input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
            output = generator.model.generate(input_ids, **model_kwargs)
            response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
            score = generator.prm_model(**generator.tokenizer(response, return_tensors="pt").to(generator.device)).logits.mean().item()
            results.append((response, score))
        for _ in range(depth - 1):
            best_responses = sorted(results, key=lambda x: x[1], reverse=True)[:breadth]
            for response, _ in best_responses:
                input_ids = generator.tokenizer(response, return_tensors="pt").input_ids.to(generator.device)
                output = generator.model.generate(input_ids, **model_kwargs)
                extended_response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
                score = generator.prm_model(**generator.tokenizer(extended_response, return_tensors="pt").to(generator.device)).logits.mean().item()
                results.append((extended_response, score))
        return max(results, key=lambda x: x[1])[0]


class COT(GenerationStrategy):
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        # TODO implement the chain-of-thought strategy
        raise NotImplementedError("Chain-of-thought strategy not implemented yet")


class ReAct(GenerationStrategy):
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        # TODO implement the ReAct framework
        raise NotImplementedError("ReAct strategy not implemented yet")

# Add other strategy implementations...
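# Extension sketch: a new strategy just subclasses GenerationStrategy and gets
# registered in LlamaGenerator._init_strategies below. The class here is a toy
# illustration of the interface, not a real strategy.
class EchoStrategy(GenerationStrategy):
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        return prompt  # returns the prompt unchanged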
# prompt_builder.py
from typing import List, Protocol, Tuple

from transformers import AutoTokenizer


class PromptTemplate(Protocol):
    """Protocol for prompt templates."""

    def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], **kwargs) -> str:
        ...


class LlamaPromptTemplate:
    def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], max_history_turns: int = 1) -> str:
        system_message = f"Please assist based on the following context: {context}"
        prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>"
        for user_msg, assistant_msg in chat_history[-max_history_turns:]:
            prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
            prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>"
        prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_input}<|eot_id|>"
        prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return prompt


class TransformersPromptTemplate:
    def __init__(self, model_path: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

    def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], **kwargs) -> str:
        messages = [
            {
                "role": "system",
                "content": f"Please assist based on the following context: {context}",
            }
        ]
        for user_msg, assistant_msg in chat_history:
            messages.extend([
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": assistant_msg}
            ])
        messages.append({"role": "user", "content": user_input})
        tokenized_chat = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        return tokenized_chat
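# Usage sketch: LlamaPromptTemplate renders Llama-3-style header tokens from a
# context, a user message, and (user, assistant) history pairs (defined only,
# never called here):
def _example_prompt_format() -> None:
    template = LlamaPromptTemplate()
    prompt = template.format(
        context="You are a geography tutor.",
        user_input="What is the capital of France?",
        chat_history=[("Hi", "Hello! How can I help?")],
    )
    print(prompt)  # starts with "<|begin_of_text|><|start_header_id|>system..."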
# health_check.py
from dataclasses import dataclass
from typing import Dict

import psutil
import torch


@dataclass
class HealthStatus:
    status: str
    gpu_memory: Dict[str, float]
    cpu_usage: float
    ram_usage: float
    model_status: Dict[str, str]


class HealthCheck:
    @staticmethod
    def check_gpu_memory() -> Dict[str, float]:
        """Allocated GPU memory per device, in GiB."""
        if torch.cuda.is_available():
            return {
                f"gpu_{i}": torch.cuda.memory_allocated(i) / 1024**3
                for i in range(torch.cuda.device_count())
            }
        return {}

    @staticmethod
    def check_system_resources() -> HealthStatus:
        return HealthStatus(
            status="healthy",
            gpu_memory=HealthCheck.check_gpu_memory(),
            cpu_usage=psutil.cpu_percent(),
            ram_usage=psutil.virtual_memory().percent,
            # TODO add more system resources like disk, network, etc.
            model_status={}  # to be filled by the model manager
        )
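# Usage sketch: take a resource snapshot and read a few fields (defined only,
# never called here):
def _example_health_report() -> None:
    status = HealthCheck.check_system_resources()
    print(status.status, f"CPU {status.cpu_usage}%", f"RAM {status.ram_usage}%")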
# llama_generator.py
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple

from config.config import GenerationConfig, ModelConfig
# Sibling modules from this repo (assuming a flat module layout)
from base_generator import BaseGenerator
from health_check import HealthStatus
from prompt_builder import LlamaPromptTemplate
from strategy import (
    BeamSearch,
    BestOfN,
    DefaultStrategy,
    DVT,
    MajorityVotingStrategy,
)


class LlamaGenerator(BaseGenerator):
    def __init__(
        self,
        llama_model_name: str,
        prm_model_path: str,
        device: Optional[str] = None,
        default_generation_config: Optional[GenerationConfig] = None,
        model_config: Optional[ModelConfig] = None,
        cache_size: int = 1000,
        max_batch_size: int = 32
    ):
        super().__init__(
            llama_model_name,
            device,
            default_generation_config,
            model_config,
            cache_size,
            max_batch_size
        )

        # Initialize models
        self.model_manager.load_model(
            "llama",
            llama_model_name,
            "llama",
            self.model_config
        )
        self.model_manager.load_model(
            "prm",
            prm_model_path,
            "gguf",
            self.model_config
        )

        # Convenience handles used by the generation strategies
        self.tokenizer = self.model_manager.tokenizers["llama"]
        self.model = self.model_manager.models["llama"]
        self.prm_model = self.model_manager.models["prm"]
        self.device = self.model_manager.device

        self.prompt_builder = LlamaPromptTemplate()
        self._init_strategies()

    def _init_strategies(self):
        self.strategies = {
            "default": DefaultStrategy(),
            "majority_voting": MajorityVotingStrategy(),
            "best_of_n": BestOfN(),
            "beam_search": BeamSearch(),
            "dvts": DVT(),
        }

    def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
        """Get generation kwargs based on config."""
        return {
            key: getattr(config, key)
            for key in [
                "max_new_tokens",
                "temperature",
                "top_p",
                "top_k",
                "repetition_penalty",
                "length_penalty",
                "do_sample"
            ]
            if hasattr(config, key)
        }

    def generate(
        self,
        prompt: str,
        model_kwargs: Dict[str, Any],
        strategy: str = "default",
        **kwargs
    ) -> str:
        if strategy not in self.strategies:
            raise ValueError(f"Unknown strategy: {strategy}")
        return self.strategies[strategy].generate(
            self,
            prompt,
            model_kwargs,
            **kwargs
        )

    def _construct_prompt(self, context: str, user_input: str, chat_history: List[Tuple[str, str]]) -> str:
        """Build a prompt from context, input, and history (used by the API below)."""
        return self.prompt_builder.format(context, user_input, chat_history)

    def generate_with_context(
        self,
        context: str,
        user_input: str,
        chat_history: List[Tuple[str, str]],
        config: Optional[GenerationConfig] = None,
        **kwargs
    ) -> str:
        """Minimal implementation wired to the API endpoints below."""
        config = config or self.default_config
        prompt = self._construct_prompt(context, user_input, chat_history)
        model_kwargs = self._get_generation_kwargs(config)
        return self.generate(
            prompt,
            model_kwargs,
            strategy=getattr(config, "strategy", "default"),
            num_samples=getattr(config, "num_samples", 5),
            **kwargs
        )

    async def generate_stream(
        self,
        prompt: str,
        config: Optional[GenerationConfig] = None
    ) -> AsyncGenerator[str, None]:
        """Naive streaming placeholder: generates the full response, then
        yields it as a single chunk. Real token-level streaming would use
        transformers' TextIteratorStreamer."""
        config = config or self.default_config
        yield self.generate(prompt, self._get_generation_kwargs(config))

    def cleanup(self) -> None:
        """Release model resources (called from the API shutdown hook below)."""
        self.model_manager.unload_model("llama")
        self.model_manager.unload_model("prm")

    def check_health(self) -> HealthStatus:
        """Check the health status of the generator."""
        return self.health_check.check_system_resources()  # TODO add model status
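# Usage sketch (commented out because it downloads models; the GGUF path would
# come from hf_hub_download, as in the API section below):
#
# gen = LlamaGenerator(
#     llama_model_name="meta-llama/Llama-3.2-1B-Instruct",
#     prm_model_path="/path/to/Llama3.1-8B-PRM-Mistral-Data-Q4_K_M.gguf",  # hypothetical local path
# )
# kwargs = gen._get_generation_kwargs(gen.default_config)
# print(gen.generate("What is 2 + 2?", kwargs, strategy="best_of_n", num_samples=3))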
| ################### | |
| ################# | |
from fastapi import FastAPI, HTTPException, BackgroundTasks, WebSocket, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, ConfigDict
from typing import List, Optional, Dict, Any, AsyncGenerator, Tuple
import asyncio
import uuid
from datetime import datetime
import json
from huggingface_hub import hf_hub_download
from contextlib import asynccontextmanager


class ChatMessage(BaseModel):
    """A single message in the chat history."""
    role: str = Field(
        ...,
        description="Role of the message sender",
        examples=["user", "assistant"]
    )
    content: str = Field(..., description="Content of the message")

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "role": "user",
                "content": "What is the capital of France?"
            }
        }
    )
class GenerationConfig(BaseModel):
    """Configuration for text generation."""
    temperature: float = Field(
        0.7,
        ge=0.0,
        le=2.0,
        description="Controls randomness in the output. Higher values (e.g., 0.8) make the output more random; lower values (e.g., 0.2) make it more focused and deterministic."
    )
    max_new_tokens: int = Field(
        100,
        ge=1,
        le=2048,
        description="Maximum number of tokens to generate"
    )
    top_p: float = Field(
        0.9,
        ge=0.0,
        le=1.0,
        description="Nucleus sampling parameter: only the smallest set of tokens whose cumulative probability reaches top_p is considered."
    )
    top_k: int = Field(
        50,
        ge=0,
        description="Only consider the top k tokens for text generation"
    )
    strategy: str = Field(
        "default",
        description="Generation strategy to use",
        examples=["default", "majority_voting", "best_of_n", "beam_search", "dvts"]
    )
    num_samples: int = Field(
        5,
        ge=1,
        le=10,
        description="Number of samples to generate (used in majority_voting and best_of_n strategies)"
    )
class GenerationRequest(BaseModel):
    """Request model for text generation."""
    context: Optional[str] = Field(
        None,
        description="Additional context to guide the generation",
        examples=["You are a helpful assistant skilled in Python programming"]
    )
    messages: List[ChatMessage] = Field(
        ...,
        description="Chat history including the current message",
        min_length=1
    )
    config: Optional[GenerationConfig] = Field(
        None,
        description="Generation configuration parameters"
    )
    stream: bool = Field(
        False,
        description="Whether to stream the response token by token"
    )

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "context": "You are a helpful assistant",
                "messages": [
                    {"role": "user", "content": "What is the capital of France?"}
                ],
                "config": {
                    "temperature": 0.7,
                    "max_new_tokens": 100
                },
                "stream": False
            }
        }
    )


class GenerationResponse(BaseModel):
    """Response model for text generation."""
    id: str = Field(..., description="Unique generation ID")
    content: str = Field(..., description="Generated text content")
    created_at: datetime = Field(
        default_factory=datetime.now,
        description="Timestamp of generation"
    )
# Model and cache management
generator: Optional[LlamaGenerator] = None


async def get_prm_model_path():
    """Download and cache the PRM model."""
    return await asyncio.to_thread(
        hf_hub_download,
        repo_id="tensorblock/Llama3.1-8B-PRM-Mistral-Data-GGUF",
        filename="Llama3.1-8B-PRM-Mistral-Data-Q4_K_M.gguf"
    )


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifecycle management for the FastAPI application."""
    # Startup: initialize the generator
    global generator
    try:
        prm_model_path = await get_prm_model_path()
        generator = LlamaGenerator(
            llama_model_name="meta-llama/Llama-3.2-1B-Instruct",
            prm_model_path=prm_model_path,
            default_generation_config=GenerationConfig(
                max_new_tokens=100,
                temperature=0.7
            )
        )
        yield
    finally:
        # Shutdown: clean up resources
        if generator:
            await asyncio.to_thread(generator.cleanup)
# FastAPI application
app = FastAPI(
    title="Inference Deluxe Service",
    description="""
A service for generating text using LLaMA models with various generation strategies.

Generation Strategies:
- default: Standard autoregressive generation
- majority_voting: Generates multiple responses and selects the most common one
- best_of_n: Generates multiple responses and selects the best based on a scoring metric
- beam_search: Uses beam search for more coherent text generation
- dvts: Diverse verifier tree search guided by the PRM
""",
    version="1.0.0",
    lifespan=lifespan
)

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
async def get_generator():
    """Dependency to get the generator instance."""
    if generator is None:
        raise HTTPException(
            status_code=503,
            detail="Generator not initialized"
        )
    return generator
def _pair_history(messages: List[ChatMessage]) -> List[Tuple[str, str]]:
    """Collapse a role-tagged message list into (user, assistant) pairs, which
    is the shape the prompt templates iterate over."""
    pairs = []
    i = 0
    while i < len(messages) - 1:
        if messages[i].role == "user" and messages[i + 1].role == "assistant":
            pairs.append((messages[i].content, messages[i + 1].content))
            i += 2
        else:
            i += 1
    return pairs


@app.post("/generate", response_model=GenerationResponse)
async def generate(
    request: GenerationRequest,
    generator: Any = Depends(get_generator)
):
    """
    Generate a text response based on the provided context and chat history.

    The generation process can be customized using various parameters in the config:
    - temperature: Controls randomness (0.0 to 2.0)
    - max_new_tokens: Maximum length of generated text
    - top_p: Nucleus sampling parameter
    - top_k: Top-k sampling parameter
    - strategy: Generation strategy to use
    - num_samples: Number of samples for applicable strategies

    Generation Strategies:
    - default: Standard generation
    - majority_voting: Generates multiple responses and uses the most common one
    - best_of_n: Generates multiple responses and picks the best
    - beam_search: Uses beam search for coherent generation
    - dvts: Diverse verifier tree search
    """
    try:
        chat_history = _pair_history(request.messages[:-1])
        user_input = request.messages[-1].content
        config = request.config or GenerationConfig()

        response = await asyncio.to_thread(
            generator.generate_with_context,
            context=request.context or "",
            user_input=user_input,
            chat_history=chat_history,
            config=config
        )

        return GenerationResponse(
            id=str(uuid.uuid4()),
            content=response
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
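# A client-side sketch for the endpoint above, using `requests` (commented out
# so it does not run at import; assumes the service runs locally on port 8000):
#
# import requests
# resp = requests.post(
#     "http://localhost:8000/generate",
#     json={
#         "context": "You are a helpful assistant",
#         "messages": [{"role": "user", "content": "What is the capital of France?"}],
#         "config": {"temperature": 0.7, "max_new_tokens": 100},
#     },
# )
# print(resp.json()["content"])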
@app.websocket("/ws/generate")
async def generate_stream(
    websocket: WebSocket,
    generator: Any = Depends(get_generator)
):
    """
    Stream generated text tokens over a WebSocket connection.

    The stream sends JSON messages with the following format:
    - During generation: {"token": "generated_token", "finished": false}
    - End of generation: {"token": "", "finished": true}
    - Error: {"error": "error_message"}
    """
    await websocket.accept()
    try:
        while True:
            request_data = await websocket.receive_text()
            request = GenerationRequest.model_validate_json(request_data)

            chat_history = _pair_history(request.messages[:-1])
            user_input = request.messages[-1].content
            config = request.config or GenerationConfig()

            async for token in generator.generate_stream(
                prompt=generator._construct_prompt(
                    context=request.context or "",
                    user_input=user_input,
                    chat_history=chat_history
                ),
                config=config
            ):
                await websocket.send_text(json.dumps({
                    "token": token,
                    "finished": False
                }))

            await websocket.send_text(json.dumps({
                "token": "",
                "finished": True
            }))
    except Exception as e:
        await websocket.send_text(json.dumps({
            "error": str(e)
        }))
    finally:
        await websocket.close()
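# A matching client sketch for the WebSocket stream, using the `websockets`
# package (commented out; assumes the service runs locally on port 8000):
#
# import websockets
#
# async def stream_chat():
#     async with websockets.connect("ws://localhost:8000/ws/generate") as ws:
#         await ws.send(json.dumps({
#             "messages": [{"role": "user", "content": "Tell me a joke"}],
#             "stream": True,
#         }))
#         while True:
#             msg = json.loads(await ws.recv())
#             if msg.get("finished") or "error" in msg:
#                 break
#             print(msg["token"], end="", flush=True)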
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)