from langfuse import Langfuse
from langfuse.decorators import observe, langfuse_context
from config.config import settings
from services.llama_generator import LlamaGenerator
import os

# Initialize Langfuse
os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-04d2302a-aa5c-4870-9703-58ab64c3bcae"
os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-d34ea200-feec-428e-a621-784fce93a5af"
os.environ["LANGFUSE_HOST"] = "https://chris4k-langfuse-template-space.hf.space"  # 🇪🇺 EU region

try:
    langfuse = Langfuse()
except Exception as e:
    langfuse = None  # keep the name defined so later references cannot NameError
    print(f"Langfuse offline: {e}")
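
# A minimal sketch of how the imported Langfuse `observe` decorator could be
# wired in (hypothetical helper; the routes below call the generator directly
# and are not traced):
#
#   @observe()
#   def traced_generate(prompt: str) -> str:
#       # Functions decorated with @observe() are traced to Langfuse whenever
#       # the client above initialized successfully.
#       return generator.generate_with_context(
#           context="", user_input=prompt, chat_history=[]
#       )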

###################

from fastapi import FastAPI, HTTPException, BackgroundTasks, WebSocket, WebSocketDisconnect, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, ConfigDict
from typing import List, Optional, Dict, Any, AsyncGenerator
import asyncio
import uuid
from datetime import datetime
import json
from huggingface_hub import hf_hub_download
from contextlib import asynccontextmanager

class ChatMessage(BaseModel):
    """A single message in the chat history."""
    role: str = Field(
        ...,
        description="Role of the message sender",
        examples=["user", "assistant"]
    )
    content: str = Field(..., description="Content of the message")

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "role": "user",
                "content": "What is the capital of France?"
            }
        }
    )

class GenerationConfig(BaseModel):
    """Configuration for text generation."""
    temperature: float = Field(
        0.7,
        ge=0.0,
        le=2.0,
        description="Controls randomness in the output. Higher values (e.g., 0.8) make the output more random; lower values (e.g., 0.2) make it more focused and deterministic."
    )
    max_new_tokens: int = Field(
        100,
        ge=1,
        le=2048,
        description="Maximum number of tokens to generate"
    )
    top_p: float = Field(
        0.9,
        ge=0.0,
        le=1.0,
        description="Nucleus sampling parameter. Only the smallest set of tokens whose cumulative probability exceeds top_p is considered."
    )
    top_k: int = Field(
        50,
        ge=0,
        description="Only the k most likely tokens are considered at each step (0 disables top-k filtering)"
    )
    strategy: str = Field(
        "default",
        description="Generation strategy to use",
        examples=["default", "majority_voting", "best_of_n", "beam_search", "dvts"]
    )
    num_samples: int = Field(
        5,
        ge=1,
        le=10,
        description="Number of samples to generate (used by the majority_voting and best_of_n strategies)"
    )

class GenerationRequest(BaseModel):
    """Request model for text generation."""
    context: Optional[str] = Field(
        None,
        description="Additional context to guide the generation",
        examples=["You are a helpful assistant skilled in Python programming"]
    )
    messages: List[ChatMessage] = Field(
        ...,
        description="Chat history including the current message",
        min_length=1
    )
    config: Optional[GenerationConfig] = Field(
        None,
        description="Generation configuration parameters"
    )
    stream: bool = Field(
        False,
        description="Whether to stream the response token by token"
    )

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "context": "You are a helpful assistant",
                "messages": [
                    {"role": "user", "content": "What is the capital of France?"}
                ],
                "config": {
                    "temperature": 0.7,
                    "max_new_tokens": 100
                },
                "stream": False
            }
        }
    )

class GenerationResponse(BaseModel):
    """Response model for text generation."""
    id: str = Field(..., description="Unique generation ID")
    content: str = Field(..., description="Generated text content")
    created_at: datetime = Field(
        default_factory=datetime.now,
        description="Timestamp of generation"
    )

# Model and cache management
async def get_prm_model_path():
    """Download and cache the PRM (process reward model) weights from the Hub."""
    return await asyncio.to_thread(
        hf_hub_download,
        repo_id="tensorblock/Llama3.1-8B-PRM-Mistral-Data-GGUF",
        filename="Llama3.1-8B-PRM-Mistral-Data-Q4_K_M.gguf"
    )

# Initialize generator globally
generator = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifecycle management for the FastAPI application."""
    # Startup: initialize the generator
    global generator
    try:
        prm_model_path = await get_prm_model_path()
        generator = LlamaGenerator(
            llama_model_name="meta-llama/Llama-3.2-1B-Instruct",
            prm_model_path=prm_model_path,
            default_generation_config=GenerationConfig(
                max_new_tokens=100,
                temperature=0.7
            )
        )
        yield
    finally:
        # Shutdown: clean up resources
        if generator:
            await asyncio.to_thread(generator.cleanup)

# FastAPI application
app = FastAPI(
    title="Inference Deluxe Service",
    description="""
A service for generating text using LLaMA models with various generation strategies.

Generation strategies:
- default: Standard autoregressive generation
- majority_voting: Generates multiple responses and selects the most common one
- best_of_n: Generates multiple responses and selects the best according to a scoring metric
- beam_search: Uses beam search for more coherent text generation
- dvts: Diverse verifier tree search, scoring branches with the PRM
""",
    version="1.0.0",
    lifespan=lifespan
)

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

async def get_generator():
    """Dependency to get the generator instance."""
    if not generator:
        raise HTTPException(
            status_code=503,
            detail="Generator not initialized"
        )
    return generator

# The route decorator was missing in the original snippet; the path below is
# an assumption, but Depends() only resolves inside a registered route.
@app.post("/generate", response_model=GenerationResponse)
async def generate(
    request: GenerationRequest,
    generator: Any = Depends(get_generator)
):
    """
    Generate a text response based on the provided context and chat history.
    """
    try:
        chat_history = [(msg.role, msg.content) for msg in request.messages[:-1]]
        user_input = request.messages[-1].content

        # Fall back to default generation parameters when none were supplied
        config = request.config or GenerationConfig()
        model_kwargs = {
            "temperature": config.temperature,
            "max_new_tokens": config.max_new_tokens,
            # Add other model kwargs as needed
        }

        # Explicitly pass additional required arguments; getattr covers knobs
        # (max_history_turns, depth, breadth) that GenerationConfig does not define
        response = await asyncio.to_thread(
            generator.generate_with_context,
            context=request.context or "",
            user_input=user_input,
            chat_history=chat_history,
            model_kwargs=model_kwargs,
            max_history_turns=getattr(config, "max_history_turns", 3),
            strategy=config.strategy,
            num_samples=config.num_samples,
            depth=getattr(config, "depth", 3),
            breadth=getattr(config, "breadth", 2),
        )
        return GenerationResponse(
            id=str(uuid.uuid4()),
            content=response
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

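# Example request against the /generate route above (a sketch; assumes the
# server is running on localhost:8000 and that the `requests` package is
# installed, which is not a dependency of this service):
#
#   import requests
#
#   payload = {
#       "context": "You are a helpful assistant",
#       "messages": [{"role": "user", "content": "What is the capital of France?"}],
#       "config": {"temperature": 0.7, "max_new_tokens": 100},
#       "stream": False,
#   }
#   resp = requests.post("http://localhost:8000/generate", json=payload)
#   resp.raise_for_status()
#   print(resp.json()["content"])
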
# As with /generate, the websocket route decorator was missing in the original
# snippet; the path is an assumption.
@app.websocket("/generate/stream")
async def generate_stream(
    websocket: WebSocket,
    generator: Any = Depends(get_generator)
):
    """
    Stream generated text tokens over a WebSocket connection.

    The stream sends JSON messages with the following format:
    - During generation: {"token": "generated_token", "finished": false}
    - End of generation: {"token": "", "finished": true}
    - Error: {"error": "error_message"}
    """
    await websocket.accept()
    try:
        while True:
            request_data = await websocket.receive_text()
            request = GenerationRequest.model_validate_json(request_data)

            chat_history = [(msg.role, msg.content) for msg in request.messages[:-1]]
            user_input = request.messages[-1].content
            config = request.config or GenerationConfig()

            async for token in generator.generate_stream(
                prompt=generator.prompt_builder.format(
                    context=request.context or "",
                    user_input=user_input,
                    chat_history=chat_history
                ),
                config=config
            ):
                await websocket.send_text(json.dumps({
                    "token": token,
                    "finished": False
                }))

            await websocket.send_text(json.dumps({
                "token": "",
                "finished": True
            }))
    except WebSocketDisconnect:
        pass  # client hung up; there is nobody left to notify
    except Exception as e:
        await websocket.send_text(json.dumps({
            "error": str(e)
        }))
        await websocket.close()

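# Example streaming client for the websocket route above (a sketch; assumes the
# third-party `websockets` package and the same payload dict as in the REST
# example):
#
#   import asyncio, json
#   import websockets
#
#   async def stream(payload):
#       async with websockets.connect("ws://localhost:8000/generate/stream") as ws:
#           await ws.send(json.dumps(payload))
#           while True:
#               msg = json.loads(await ws.recv())
#               if "error" in msg or msg.get("finished"):
#                   break
#               print(msg["token"], end="", flush=True)
#
#   asyncio.run(stream(payload))
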
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)