from langfuse import Langfuse
from langfuse.decorators import observe, langfuse_context

from config.config import settings
from services.llama_generator import LlamaGenerator
import os

# Langfuse credentials for tracing. In a real deployment these should come
# from the environment or `settings` rather than being hardcoded here.
os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-9f2c32d2-266f-421d-9b87-51377f0a268c"
os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-229e10c5-6210-4a4b-a432-0f17bc66e56c"
os.environ["LANGFUSE_HOST"] = "https://chris4k-langfuse-template-space.hf.space"

try:
    langfuse = Langfuse()
except Exception as e:
    langfuse = None
    print(f"Langfuse offline: {e}")

from fastapi import FastAPI, HTTPException, BackgroundTasks, WebSocket, WebSocketDisconnect, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, ConfigDict
from typing import List, Optional, Dict, Any, AsyncGenerator
import asyncio
import uuid
from datetime import datetime
import json
from huggingface_hub import hf_hub_download
from contextlib import asynccontextmanager

class ChatMessage(BaseModel):
    """A single message in the chat history."""
    role: str = Field(
        ...,
        description="Role of the message sender",
        examples=["user", "assistant"]
    )
    content: str = Field(..., description="Content of the message")

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "role": "user",
                "content": "What is the capital of France?"
            }
        }
    )


class GenerationConfig(BaseModel):
    """Configuration for text generation."""
    temperature: float = Field(
        0.7,
        ge=0.0,
        le=2.0,
        description="Controls randomness in the output. Higher values (e.g., 0.8) make the output more random; lower values (e.g., 0.2) make it more focused and deterministic."
    )
    max_new_tokens: int = Field(
        100,
        ge=1,
        le=2048,
        description="Maximum number of tokens to generate"
    )
    top_p: float = Field(
        0.9,
        ge=0.0,
        le=1.0,
        description="Nucleus sampling parameter. Sampling is restricted to the smallest set of most likely tokens whose cumulative probability reaches top_p."
    )
    top_k: int = Field(
        50,
        ge=0,
        description="Only the k most likely tokens are considered at each generation step"
    )
    strategy: str = Field(
        "default",
        description="Generation strategy to use",
        examples=["default", "majority_voting", "best_of_n", "beam_search", "dvts"]
    )
    num_samples: int = Field(
        5,
        ge=1,
        le=10,
        description="Number of samples to generate (used in the majority_voting and best_of_n strategies)"
    )

class GenerationRequest(BaseModel):
    """Request model for text generation."""
    context: Optional[str] = Field(
        None,
        description="Additional context to guide the generation",
        examples=["You are a helpful assistant skilled in Python programming"]
    )
    messages: List[ChatMessage] = Field(
        ...,
        description="Chat history including the current message",
        min_length=1
    )
    config: Optional[GenerationConfig] = Field(
        None,
        description="Generation configuration parameters"
    )
    stream: bool = Field(
        False,
        description="Whether to stream the response token by token"
    )

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "context": "You are a helpful assistant",
                "messages": [
                    {"role": "user", "content": "What is the capital of France?"}
                ],
                "config": {
                    "temperature": 0.7,
                    "max_new_tokens": 100
                },
                "stream": False
            }
        }
    )


class GenerationResponse(BaseModel):
    """Response model for text generation."""
    id: str = Field(..., description="Unique generation ID")
    content: str = Field(..., description="Generated text content")
    created_at: datetime = Field(
        default_factory=datetime.now,
        description="Timestamp of generation"
    )

async def get_prm_model_path():
    """Download and cache the PRM (process reward model) weights from the Hugging Face Hub."""
    return await asyncio.to_thread(
        hf_hub_download,
        repo_id="tensorblock/Llama3.1-8B-PRM-Mistral-Data-GGUF",
        filename="Llama3.1-8B-PRM-Mistral-Data-Q4_K_M.gguf"
    )


# Global generator instance, created during application startup (see lifespan below).
generator = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifecycle management for the FastAPI application."""
    global generator
    try:
        # Download the PRM weights and build the generator before serving requests.
        prm_model_path = await get_prm_model_path()
        generator = LlamaGenerator(
            llama_model_name="meta-llama/Llama-3.2-1B-Instruct",
            prm_model_path=prm_model_path,
            default_generation_config=GenerationConfig(
                max_new_tokens=100,
                temperature=0.7
            )
        )
        yield
    finally:
        # Release model resources on shutdown.
        if generator:
            await asyncio.to_thread(generator.cleanup)

app = FastAPI(
    title="Inference Deluxe Service",
    description="""
    A service for generating text using LLaMA models with various generation strategies.

    Generation Strategies:
    - default: Standard autoregressive generation
    - majority_voting: Generates multiple responses and selects the most common one
    - best_of_n: Generates multiple responses and selects the best based on a scoring metric
    - beam_search: Uses beam search for more coherent text generation
    - dvts: Diverse verifier tree search, which expands and scores several candidate branches
    """,
    version="1.0.0",
    lifespan=lifespan
)
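# The strategies above map onto the `config.strategy` field of GenerationRequest.
# A minimal client sketch (assumptions: the service is reachable on
# localhost:8000, model loading has finished, and `requests` is installed):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/generate",
#       json={
#           "messages": [{"role": "user", "content": "What is the capital of France?"}],
#           "config": {"strategy": "best_of_n", "num_samples": 5},
#       },
#   )
#   print(resp.json()["content"])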
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

async def get_generator():
    """Dependency to get the generator instance."""
    if not generator:
        raise HTTPException(
            status_code=503,
            detail="Generator not initialized"
        )
    return generator

@app.post(
    "/generate",
    response_model=GenerationResponse,
    tags=["generation"],
    summary="Generate text response",
    response_description="Generated text with unique identifier"
)
async def generate(
    request: GenerationRequest,
    generator: Any = Depends(get_generator)
):
    """
    Generate a text response based on the provided context and chat history.
    """
    try:
        # Everything except the last message is treated as prior chat history.
        chat_history = [(msg.role, msg.content) for msg in request.messages[:-1]]
        user_input = request.messages[-1].content

        config = request.config or GenerationConfig()
        model_kwargs = {
            "temperature": config.temperature,
            "max_new_tokens": config.max_new_tokens,
        }

        response = await asyncio.to_thread(
            generator.generate_with_context,
            context=request.context or "",
            user_input=user_input,
            chat_history=chat_history,
            model_kwargs=model_kwargs,
            # max_history_turns, depth and breadth are not fields of
            # GenerationConfig, so they fall back to these defaults.
            max_history_turns=getattr(config, "max_history_turns", 3),
            strategy=config.strategy,
            num_samples=config.num_samples,
            depth=getattr(config, "depth", 3),
            breadth=getattr(config, "breadth", 2),
        )

        return GenerationResponse(
            id=str(uuid.uuid4()),
            content=response
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.websocket("/generate/stream")
async def generate_stream(
    websocket: WebSocket,
    generator: Any = Depends(get_generator)
):
    """
    Stream generated text tokens over a WebSocket connection.

    The stream sends JSON messages with the following format:
    - During generation: {"token": "generated_token", "finished": false}
    - End of generation: {"token": "", "finished": true}
    - Error: {"error": "error_message"}
    """
    await websocket.accept()

    try:
        while True:
            request_data = await websocket.receive_text()
            request = GenerationRequest.model_validate_json(request_data)

            chat_history = [(msg.role, msg.content) for msg in request.messages[:-1]]
            user_input = request.messages[-1].content

            config = request.config or GenerationConfig()

            async for token in generator.generate_stream(
                prompt=generator.prompt_builder.format(
                    context=request.context or "",
                    user_input=user_input,
                    chat_history=chat_history
                ),
                config=config
            ):
                await websocket.send_text(json.dumps({
                    "token": token,
                    "finished": False
                }))

            await websocket.send_text(json.dumps({
                "token": "",
                "finished": True
            }))

    except WebSocketDisconnect:
        # The client closed the connection; there is nothing left to send.
        pass
    except Exception as e:
        await websocket.send_text(json.dumps({
            "error": str(e)
        }))
        await websocket.close()
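# A minimal client sketch for the streaming endpoint (assumptions: the
# `websockets` package is installed and the service listens on port 8000):
#
#   import asyncio, json, websockets
#
#   async def stream_demo():
#       async with websockets.connect("ws://localhost:8000/generate/stream") as ws:
#           await ws.send(json.dumps({
#               "messages": [{"role": "user", "content": "Tell me a short joke"}],
#               "stream": True
#           }))
#           while True:
#               msg = json.loads(await ws.recv())
#               if "error" in msg or msg.get("finished"):
#                   break
#               print(msg["token"], end="", flush=True)
#
#   asyncio.run(stream_demo())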
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)