from langfuse import Langfuse
from langfuse.decorators import observe, langfuse_context

from config.config import settings
import os

os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-9f2c32d2-266f-421d-9b87-51377f0a268c"
os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-229e10c5-6210-4a4b-a432-0f17bc66e56c"
os.environ["LANGFUSE_HOST"] = "https://chris4k-langfuse-template-space.hf.space"

try:
    langfuse = Langfuse()
except Exception:
    # Keep the app usable when the Langfuse backend is unreachable.
    langfuse = None
    print("Langfuse Offline")
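
# Sketch (assumption, not part of the original file): the `observe` decorator
# imported above can wrap a generation call so each invocation is traced in
# Langfuse. The helper name and its `generator` argument are illustrative only.
@observe()
def traced_generate(generator, prompt: str) -> str:
    # The decorator records the inputs and outputs of this call as a Langfuse trace.
    return generator.generate(prompt, model_kwargs={})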

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from llama_cpp import Llama
from typing import Optional, Dict, Any
import logging
from functools import lru_cache
from config.config import GenerationConfig, ModelConfig

class ModelManager:
    def __init__(self, device: Optional[str] = None):
        self.logger = logging.getLogger(__name__)
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.models: Dict[str, Any] = {}
        self.tokenizers: Dict[str, Any] = {}

    def load_model(self, model_id: str, model_path: str, model_type: str, config: ModelConfig) -> None:
        """Load a model with specified configuration."""
        try:
            if model_type == "llama":
                self.tokenizers[model_id] = AutoTokenizer.from_pretrained(
                    model_path,
                    padding_side='left',
                    trust_remote_code=True,
                    **config.tokenizer_kwargs
                )
                if self.tokenizers[model_id].pad_token is None:
                    self.tokenizers[model_id].pad_token = self.tokenizers[model_id].eos_token

                self.models[model_id] = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    device_map="auto",
                    trust_remote_code=True,
                    **config.model_kwargs
                )
            elif model_type == "gguf":
                self.models[model_id] = self._load_quantized_model(
                    model_path,
                    **config.quantization_kwargs
                )
        except Exception as e:
            self.logger.error(f"Failed to load model {model_id}: {str(e)}")
            raise

    def unload_model(self, model_id: str) -> None:
        """Unload a model and free resources."""
        if model_id in self.models:
            del self.models[model_id]
        if model_id in self.tokenizers:
            del self.tokenizers[model_id]
        torch.cuda.empty_cache()

    def _load_quantized_model(self, model_path: str, **kwargs) -> Llama:
        """Load a quantized GGUF model."""
        try:
            n_gpu_layers = -1 if torch.cuda.is_available() else 0
            model = Llama(
                model_path=model_path,
                n_ctx=kwargs.get('n_ctx', 2048),
                n_batch=kwargs.get('n_batch', 512),
                n_gpu_layers=kwargs.get('n_gpu_layers', n_gpu_layers),
                verbose=kwargs.get('verbose', False)
            )
            return model
        except Exception as e:
            self.logger.error(f"Failed to load GGUF model: {str(e)}")
            raise
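
# Usage sketch (illustrative only; assumes ModelConfig exposes the
# tokenizer_kwargs/model_kwargs/quantization_kwargs dicts used above):
def _example_model_manager() -> None:
    manager = ModelManager()
    cfg = ModelConfig()
    # Load a transformers checkpoint under the id "llama" ...
    manager.load_model("llama", "meta-llama/Llama-3.2-1B-Instruct", "llama", cfg)
    # ... and release it again when done.
    manager.unload_model("llama")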

from collections import OrderedDict
from typing import Optional, Tuple


class ResponseCache:
    """LRU cache for generated responses, keyed by prompt and generation config."""

    def __init__(self, cache_size: int = 1000):
        self.cache_size = cache_size
        self._cache: "OrderedDict[Tuple[str, str], Tuple[str, float]]" = OrderedDict()

    def _make_key(self, prompt: str, config: GenerationConfig) -> Tuple[str, str]:
        return prompt, str(hash(str(config.__dict__)))

    def cache_response(self, prompt: str, config: GenerationConfig, response: str, score: float) -> None:
        key = self._make_key(prompt, config)
        self._cache[key] = (response, score)
        self._cache.move_to_end(key)
        if len(self._cache) > self.cache_size:
            # Evict the least recently used entry.
            self._cache.popitem(last=False)

    def get_response(self, prompt: str, config: GenerationConfig) -> Optional[Tuple[str, float]]:
        key = self._make_key(prompt, config)
        if key in self._cache:
            self._cache.move_to_end(key)
            return self._cache[key]
        return None
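
# Usage sketch (illustrative only): store a scored response and read it back.
def _example_response_cache() -> None:
    cache = ResponseCache(cache_size=100)
    cfg = GenerationConfig()
    cache.cache_response("What is the capital of France?", cfg, "Paris.", 0.92)
    assert cache.get_response("What is the capital of France?", cfg) == ("Paris.", 0.92)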

from typing import Any, Dict, List
import asyncio


class BatchProcessor:
    def __init__(self, max_batch_size: int = 32, max_wait_time: float = 0.1):
        self.max_batch_size = max_batch_size
        self.max_wait_time = max_wait_time
        self.pending_requests: List[Dict] = []
        self.lock = asyncio.Lock()

    async def add_request(self, request: Dict) -> Any:
        async with self.lock:
            self.pending_requests.append(request)
            if len(self.pending_requests) >= self.max_batch_size:
                return await self._process_batch()
        # Release the lock while waiting so other requests can join the batch.
        await asyncio.sleep(self.max_wait_time)
        async with self.lock:
            if self.pending_requests:
                return await self._process_batch()

    async def _process_batch(self) -> List[Any]:
        batch = self.pending_requests[:self.max_batch_size]
        self.pending_requests = self.pending_requests[self.max_batch_size:]
        # Placeholder: the actual batched inference would happen here.
        return batch
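
# Usage sketch (illustrative only): several concurrent callers share one batch.
async def _example_batch_processor() -> None:
    processor = BatchProcessor(max_batch_size=4, max_wait_time=0.05)
    results = await asyncio.gather(
        *(processor.add_request({"prompt": f"question {i}"}) for i in range(4))
    )
    # The caller that fills the batch receives it; the others receive None.
    print(results)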

from abc import ABC, abstractmethod
from typing import AsyncGenerator, Dict, Any, Optional, List, Tuple
from dataclasses import dataclass
from logging import getLogger

from config.config import GenerationConfig, ModelConfig

class BaseGenerator(ABC):
    """Base class for all generator implementations."""

    def __init__(
        self,
        model_name: str,
        device: Optional[str] = None,
        default_generation_config: Optional[GenerationConfig] = None,
        model_config: Optional[ModelConfig] = None,
        cache_size: int = 1000,
        max_batch_size: int = 32
    ):
        self.logger = getLogger(__name__)
        self.model_manager = ModelManager(device)
        self.cache = ResponseCache(cache_size)
        self.batch_processor = BatchProcessor(max_batch_size)
        self.health_check = HealthCheck()

        self.default_config = default_generation_config or GenerationConfig()
        self.model_config = model_config or ModelConfig()

    @abstractmethod
    async def generate_stream(
        self,
        prompt: str,
        config: Optional[GenerationConfig] = None
    ) -> AsyncGenerator[str, None]:
        pass

    @abstractmethod
    def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
        pass

    @abstractmethod
    def generate(self, prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        pass

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple


class GenerationStrategy(ABC):
    """Base class for generation strategies."""

    @abstractmethod
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        pass


class DefaultStrategy(GenerationStrategy):
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        output = generator.model.generate(input_ids, **model_kwargs)
        return generator.tokenizer.decode(output[0], skip_special_tokens=True)


class MajorityVotingStrategy(GenerationStrategy):
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        outputs = []
        for _ in range(num_samples):
            input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
            output = generator.model.generate(input_ids, **model_kwargs)
            outputs.append(generator.tokenizer.decode(output[0], skip_special_tokens=True))
        return max(set(outputs), key=outputs.count)


class BestOfN(GenerationStrategy):
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        scored_outputs = []
        for _ in range(num_samples):
            input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
            output = generator.model.generate(input_ids, **model_kwargs)
            response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
            # Score each candidate with the process reward model and keep the best.
            score = generator.prm_model(**generator.tokenizer(response, return_tensors="pt").to(generator.device)).logits.mean().item()
            scored_outputs.append((response, score))
        return max(scored_outputs, key=lambda x: x[1])[0]


class BeamSearch(GenerationStrategy):
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        outputs = generator.model.generate(
            input_ids,
            num_beams=num_samples,
            num_return_sequences=num_samples,
            **model_kwargs
        )
        # The strategy interface returns a single string, so keep the top beam.
        return generator.tokenizer.decode(outputs[0], skip_special_tokens=True)


class DVT(GenerationStrategy):
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, breadth: int = 3, depth: int = 2, **kwargs) -> str:
        # Breadth controls how many candidates are kept per level, depth how many
        # expansion rounds are run.
        results = []
        for _ in range(breadth):
            input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
            output = generator.model.generate(input_ids, **model_kwargs)
            response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
            score = generator.prm_model(**generator.tokenizer(response, return_tensors="pt").to(generator.device)).logits.mean().item()
            results.append((response, score))

        for _ in range(depth - 1):
            best_responses = sorted(results, key=lambda x: x[1], reverse=True)[:breadth]
            for response, _ in best_responses:
                input_ids = generator.tokenizer(response, return_tensors="pt").input_ids.to(generator.device)
                output = generator.model.generate(input_ids, **model_kwargs)
                extended_response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
                score = generator.prm_model(**generator.tokenizer(extended_response, return_tensors="pt").to(generator.device)).logits.mean().item()
                results.append((extended_response, score))
        return max(results, key=lambda x: x[1])[0]


class COT(GenerationStrategy):
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        # Chain-of-thought prompting is not implemented yet.
        return "Not implemented yet"


class ReAct(GenerationStrategy):
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        # ReAct-style reasoning and acting is not implemented yet.
        return "Not implemented yet"
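
# Usage sketch (illustrative): strategies are looked up by name on the generator,
# so callers pick one via the `strategy` argument of LlamaGenerator.generate below.
def _example_strategy_call(generator) -> str:
    return generator.generate(
        "Explain beam search in one sentence.",
        model_kwargs={"max_new_tokens": 64},
        strategy="best_of_n",
        num_samples=3,
    )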

from typing import Protocol, List, Tuple
from transformers import AutoTokenizer


class PromptTemplate(Protocol):
    """Protocol for prompt templates."""
    def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], **kwargs) -> str:
        ...


class LlamaPromptTemplate:
    def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], max_history_turns: int = 1) -> str:
        system_message = f"Please assist based on the following context: {context}"
        prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>"

        for user_msg, assistant_msg in chat_history[-max_history_turns:]:
            prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
            prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>"

        prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_input}<|eot_id|>"
        prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return prompt
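
# Usage sketch (illustrative): the template emits the Llama 3 chat wire format,
# e.g. "<|begin_of_text|><|start_header_id|>system<|end_header_id|>...".
def _example_llama_prompt() -> str:
    template = LlamaPromptTemplate()
    return template.format(
        context="You are a helpful assistant",
        user_input="What is the capital of France?",
        chat_history=[("Hi", "Hello! How can I help?")],
    )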

class TransformersPromptTemplate:
    def __init__(self, model_path: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

    def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], **kwargs) -> str:
        messages = [
            {
                "role": "system",
                "content": f"Please assist based on the following context: {context}",
            }
        ]

        for user_msg, assistant_msg in chat_history:
            messages.extend([
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": assistant_msg}
            ])

        messages.append({"role": "user", "content": user_input})

        tokenized_chat = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        return tokenized_chat

import psutil
import torch
from dataclasses import dataclass
from typing import Dict, Any


@dataclass
class HealthStatus:
    status: str
    gpu_memory: Dict[str, float]
    cpu_usage: float
    ram_usage: float
    model_status: Dict[str, str]


class HealthCheck:
    @staticmethod
    def check_gpu_memory() -> Dict[str, float]:
        if torch.cuda.is_available():
            return {
                f"gpu_{i}": torch.cuda.memory_allocated(i) / 1024**3
                for i in range(torch.cuda.device_count())
            }
        return {}

    @staticmethod
    def check_system_resources() -> HealthStatus:
        return HealthStatus(
            status="healthy",
            gpu_memory=HealthCheck.check_gpu_memory(),
            cpu_usage=psutil.cpu_percent(),
            ram_usage=psutil.virtual_memory().percent,
            model_status={}
        )
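
# Usage sketch (illustrative): GPU memory is reported in GiB per device,
# alongside CPU and RAM utilisation percentages.
def _example_health_check() -> None:
    status = HealthCheck.check_system_resources()
    print(status.status, status.cpu_usage, status.ram_usage, status.gpu_memory)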

from config.config import GenerationConfig, ModelConfig


class LlamaGenerator(BaseGenerator):
    def __init__(
        self,
        llama_model_name: str,
        prm_model_path: str,
        device: Optional[str] = None,
        default_generation_config: Optional[GenerationConfig] = None,
        model_config: Optional[ModelConfig] = None,
        cache_size: int = 1000,
        max_batch_size: int = 32
    ):
        super().__init__(
            llama_model_name,
            device,
            default_generation_config,
            model_config,
            cache_size,
            max_batch_size
        )

        self.model_manager.load_model(
            "llama",
            llama_model_name,
            "llama",
            self.model_config
        )
        self.model_manager.load_model(
            "prm",
            prm_model_path,
            "gguf",
            self.model_config
        )

        # Convenience handles used by the generation strategies.
        self.device = self.model_manager.device
        self.tokenizer = self.model_manager.tokenizers["llama"]
        self.model = self.model_manager.models["llama"]
        self.prm_model = self.model_manager.models["prm"]

        self.prompt_builder = LlamaPromptTemplate()
        self._init_strategies()

    def _init_strategies(self):
        self.strategies = {
            "default": DefaultStrategy(),
            "majority_voting": MajorityVotingStrategy(),
            "best_of_n": BestOfN(),
            "beam_search": BeamSearch(),
            "dvts": DVT(),
        }

    def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
        """Get generation kwargs based on config."""
        return {
            key: getattr(config, key)
            for key in [
                "max_new_tokens",
                "temperature",
                "top_p",
                "top_k",
                "repetition_penalty",
                "length_penalty",
                "do_sample"
            ]
            if hasattr(config, key)
        }

    def generate(
        self,
        prompt: str,
        model_kwargs: Dict[str, Any],
        strategy: str = "default",
        **kwargs
    ) -> str:
        if strategy not in self.strategies:
            raise ValueError(f"Unknown strategy: {strategy}")

        return self.strategies[strategy].generate(
            self,
            prompt,
            model_kwargs,
            **kwargs
        )

    def check_health(self) -> HealthStatus:
        """Check the health status of the generator."""
        return self.health_check.check_system_resources()
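
# The API layer below calls generator.generate_with_context(...), generator.
# _construct_prompt(...) and generator.cleanup(), which are not defined above.
# A minimal sketch of that missing glue, written as free functions purely for
# illustration (the names and defaults are assumptions, not the original code):
def generate_with_context(generator: LlamaGenerator, context: str, user_input: str,
                          chat_history: List[Tuple[str, str]],
                          config: Optional[GenerationConfig] = None) -> str:
    config = config or generator.default_config
    # Build the Llama chat prompt, then dispatch to the configured strategy.
    prompt = generator.prompt_builder.format(context, user_input, chat_history)
    model_kwargs = generator._get_generation_kwargs(config)
    strategy = getattr(config, "strategy", "default")
    return generator.generate(prompt, model_kwargs, strategy=strategy)


def cleanup(generator: LlamaGenerator) -> None:
    # Release both models loaded in LlamaGenerator.__init__.
    generator.model_manager.unload_model("llama")
    generator.model_manager.unload_model("prm")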

from fastapi import FastAPI, HTTPException, BackgroundTasks, WebSocket, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, ConfigDict
from typing import List, Optional, Dict, Any, AsyncGenerator
import asyncio
import uuid
from datetime import datetime
import json
from huggingface_hub import hf_hub_download
from contextlib import asynccontextmanager

# Module-level handle populated during application startup (see lifespan below).
generator = None

class ChatMessage(BaseModel):
    """A single message in the chat history."""
    role: str = Field(
        ...,
        description="Role of the message sender",
        examples=["user", "assistant"]
    )
    content: str = Field(..., description="Content of the message")

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "role": "user",
                "content": "What is the capital of France?"
            }
        }
    )


class GenerationConfig(BaseModel):
    """Configuration for text generation."""
    temperature: float = Field(
        0.7,
        ge=0.0,
        le=2.0,
        description="Controls randomness in the output. Higher values (e.g., 0.8) make the output more random; lower values (e.g., 0.2) make it more focused and deterministic."
    )
    max_new_tokens: int = Field(
        100,
        ge=1,
        le=2048,
        description="Maximum number of tokens to generate"
    )
    top_p: float = Field(
        0.9,
        ge=0.0,
        le=1.0,
        description="Nucleus sampling parameter. Only the smallest set of tokens whose cumulative probability exceeds top_p is considered."
    )
    top_k: int = Field(
        50,
        ge=0,
        description="Only consider the top k tokens for text generation"
    )
    strategy: str = Field(
        "default",
        description="Generation strategy to use",
        examples=["default", "majority_voting", "best_of_n", "beam_search", "dvts"]
    )
    num_samples: int = Field(
        5,
        ge=1,
        le=10,
        description="Number of samples to generate (used in majority_voting and best_of_n strategies)"
    )

class GenerationRequest(BaseModel):
    """Request model for text generation."""
    context: Optional[str] = Field(
        None,
        description="Additional context to guide the generation",
        examples=["You are a helpful assistant skilled in Python programming"]
    )
    messages: List[ChatMessage] = Field(
        ...,
        description="Chat history including the current message",
        min_length=1
    )
    config: Optional[GenerationConfig] = Field(
        None,
        description="Generation configuration parameters"
    )
    stream: bool = Field(
        False,
        description="Whether to stream the response token by token"
    )

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "context": "You are a helpful assistant",
                "messages": [
                    {"role": "user", "content": "What is the capital of France?"}
                ],
                "config": {
                    "temperature": 0.7,
                    "max_new_tokens": 100
                },
                "stream": False
            }
        }
    )

class GenerationResponse(BaseModel):
    """Response model for text generation."""
    id: str = Field(..., description="Unique generation ID")
    content: str = Field(..., description="Generated text content")
    created_at: datetime = Field(
        default_factory=datetime.now,
        description="Timestamp of generation"
    )


async def get_prm_model_path():
    """Download and cache the PRM model."""
    return await asyncio.to_thread(
        hf_hub_download,
        repo_id="tensorblock/Llama3.1-8B-PRM-Mistral-Data-GGUF",
        filename="Llama3.1-8B-PRM-Mistral-Data-Q4_K_M.gguf"
    )

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifecycle management for the FastAPI application."""
    global generator
    try:
        prm_model_path = await get_prm_model_path()
        generator = LlamaGenerator(
            llama_model_name="meta-llama/Llama-3.2-1B-Instruct",
            prm_model_path=prm_model_path,
            default_generation_config=GenerationConfig(
                max_new_tokens=100,
                temperature=0.7
            )
        )
        yield
    finally:
        if generator:
            await asyncio.to_thread(generator.cleanup)

app = FastAPI(
    title="Inference Deluxe Service",
    description="""
    A service for generating text using LLaMA models with various generation strategies.

    Generation Strategies:
    - default: Standard autoregressive generation
    - majority_voting: Generates multiple responses and selects the most common one
    - best_of_n: Generates multiple responses and selects the best based on a scoring metric
    - beam_search: Uses beam search for more coherent text generation
    - dvts: Diverse verifier tree search guided by a process reward model
    """,
    version="1.0.0",
    lifespan=lifespan
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


async def get_generator():
    """Dependency to get the generator instance."""
    if not generator:
        raise HTTPException(
            status_code=503,
            detail="Generator not initialized"
        )
    return generator

@app.post(
    "/generate",
    response_model=GenerationResponse,
    tags=["generation"],
    summary="Generate text response",
    response_description="Generated text with unique identifier"
)
async def generate(
    request: GenerationRequest,
    generator: Any = Depends(get_generator)
):
    """
    Generate a text response based on the provided context and chat history.

    The generation process can be customized using various parameters in the config:
    - temperature: Controls randomness (0.0 to 2.0)
    - max_new_tokens: Maximum length of generated text
    - top_p: Nucleus sampling parameter
    - top_k: Top-k sampling parameter
    - strategy: Generation strategy to use
    - num_samples: Number of samples for applicable strategies

    Generation Strategies:
    - default: Standard generation
    - majority_voting: Generates multiple responses and uses the most common one
    - best_of_n: Generates multiple responses and picks the best
    - beam_search: Uses beam search for coherent generation
    - dvts: Diverse verifier tree search
    """
    try:
        chat_history = [(msg.role, msg.content) for msg in request.messages[:-1]]
        user_input = request.messages[-1].content

        config = request.config or GenerationConfig()

        response = await asyncio.to_thread(
            generator.generate_with_context,
            context=request.context or "",
            user_input=user_input,
            chat_history=chat_history,
            config=config
        )

        return GenerationResponse(
            id=str(uuid.uuid4()),
            content=response
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
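
# Usage sketch (illustrative, assuming the service runs on localhost:8000 and
# the `requests` package is installed):
def _example_generate_request() -> None:
    import requests

    payload = {
        "context": "You are a helpful assistant",
        "messages": [{"role": "user", "content": "What is the capital of France?"}],
        "config": {"temperature": 0.7, "max_new_tokens": 100, "strategy": "default"},
    }
    resp = requests.post("http://localhost:8000/generate", json=payload, timeout=120)
    print(resp.json()["content"])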

@app.websocket("/generate/stream")
async def generate_stream(
    websocket: WebSocket,
    generator: Any = Depends(get_generator)
):
    """
    Stream generated text tokens over a WebSocket connection.

    The stream sends JSON messages with the following format:
    - During generation: {"token": "generated_token", "finished": false}
    - End of generation: {"token": "", "finished": true}
    - Error: {"error": "error_message"}
    """
    await websocket.accept()

    try:
        while True:
            request_data = await websocket.receive_text()
            request = GenerationRequest.model_validate_json(request_data)

            chat_history = [(msg.role, msg.content) for msg in request.messages[:-1]]
            user_input = request.messages[-1].content

            config = request.config or GenerationConfig()

            async for token in generator.generate_stream(
                prompt=generator._construct_prompt(
                    context=request.context or "",
                    user_input=user_input,
                    chat_history=chat_history
                ),
                config=config
            ):
                await websocket.send_text(json.dumps({
                    "token": token,
                    "finished": False
                }))

            await websocket.send_text(json.dumps({
                "token": "",
                "finished": True
            }))

    except Exception as e:
        await websocket.send_text(json.dumps({
            "error": str(e)
        }))
    finally:
        await websocket.close()
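
# Usage sketch (illustrative, assuming the `websockets` package is installed):
async def _example_stream_client() -> None:
    import websockets

    payload = {
        "messages": [{"role": "user", "content": "Tell me a short story."}],
        "stream": True,
    }
    async with websockets.connect("ws://localhost:8000/generate/stream") as ws:
        await ws.send(json.dumps(payload))
        while True:
            message = json.loads(await ws.recv())
            if message.get("error") or message.get("finished"):
                break
            print(message["token"], end="", flush=True)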

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)