from langfuse import Langfuse

from config.config import settings
from services.llama_generator import LlamaGenerator
import os

# Langfuse credentials. Hardcoded here for the demo deployment; in production
# these should come from the environment or a secrets manager instead.
os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-04d2302a-aa5c-4870-9703-58ab64c3bcae"
os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-d34ea200-feec-428e-a621-784fce93a5af"
os.environ["LANGFUSE_HOST"] = "https://chris4k-langfuse-template-space.hf.space"

try:
    langfuse = Langfuse()
except Exception as e:
    langfuse = None
    print(f"Langfuse offline: {e}")
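# Example of recording a trace with the client above (a sketch using the
# Langfuse v2 low-level API; the trace/generation names and values are
# illustrative only, not part of this service):
#
#   if langfuse:
#       trace = langfuse.trace(name="chat-generation")
#       trace.generation(
#           name="llama-generate",
#           model="meta-llama/Llama-3.2-1B-Instruct",
#           input="What is the capital of France?",
#           output="Paris",
#       )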
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, ConfigDict
from typing import List, Optional, Any
import asyncio
import uuid
from datetime import datetime
import json
from huggingface_hub import hf_hub_download
from contextlib import asynccontextmanager

class ChatMessage(BaseModel):
    """A single message in the chat history."""
    role: str = Field(
        ...,
        description="Role of the message sender",
        examples=["user", "assistant"]
    )
    content: str = Field(..., description="Content of the message")

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "role": "user",
                "content": "What is the capital of France?"
            }
        }
    )

class GenerationConfig(BaseModel):
    """Configuration for text generation."""
    temperature: float = Field(
        0.7,
        ge=0.0,
        le=2.0,
        description="Controls randomness in the output. Higher values (e.g., 0.8) make the output more random; lower values (e.g., 0.2) make it more focused and deterministic."
    )
    max_new_tokens: int = Field(
        100,
        ge=1,
        le=2048,
        description="Maximum number of tokens to generate"
    )
    top_p: float = Field(
        0.9,
        ge=0.0,
        le=1.0,
        description="Nucleus sampling parameter. Only the smallest set of tokens whose cumulative probability reaches top_p is considered."
    )
    top_k: int = Field(
        50,
        ge=0,
        description="Only consider the top k tokens for text generation"
    )
    strategy: str = Field(
        "default",
        description="Generation strategy to use",
        examples=["default", "majority_voting", "best_of_n", "beam_search", "dvts"]
    )
    num_samples: int = Field(
        5,
        ge=1,
        le=10,
        description="Number of samples to generate (used in majority_voting and best_of_n strategies)"
    )
    max_history_turns: int = Field(
        3,
        ge=0,
        description="Maximum number of prior chat turns to include in the prompt"
    )
    depth: int = Field(
        3,
        ge=1,
        description="Search depth (used in the dvts strategy)"
    )
    breadth: int = Field(
        2,
        ge=1,
        description="Search breadth (used in the dvts strategy)"
    )

class GenerationRequest(BaseModel):
    """Request model for text generation."""
    context: Optional[str] = Field(
        None,
        description="Additional context to guide the generation",
        examples=["You are a helpful assistant skilled in Python programming"]
    )
    messages: List[ChatMessage] = Field(
        ...,
        description="Chat history including the current message",
        min_length=1
    )
    config: Optional[GenerationConfig] = Field(
        None,
        description="Generation configuration parameters"
    )
    stream: bool = Field(
        False,
        description="Whether to stream the response token by token"
    )

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "context": "You are a helpful assistant",
                "messages": [
                    {"role": "user", "content": "What is the capital of France?"}
                ],
                "config": {
                    "temperature": 0.7,
                    "max_new_tokens": 100
                },
                "stream": False
            }
        }
    )


class GenerationResponse(BaseModel):
    """Response model for text generation."""
    id: str = Field(..., description="Unique generation ID")
    content: str = Field(..., description="Generated text content")
    created_at: datetime = Field(
        default_factory=datetime.now,
        description="Timestamp of generation"
    )

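# Example of building and validating a request in code (a sketch; the payload
# mirrors the documented schema example, and model_validate is standard
# Pydantic v2):
#
#   req = GenerationRequest.model_validate({
#       "context": "You are a helpful assistant",
#       "messages": [{"role": "user", "content": "What is the capital of France?"}],
#       "config": {"temperature": 0.7, "max_new_tokens": 100},
#   })
#
# Out-of-range values are rejected at the API boundary, e.g. temperature=3.0
# raises pydantic.ValidationError because the field is capped at le=2.0.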
async def get_prm_model_path() -> str:
    """Download and cache the PRM (process reward model) weights from the Hugging Face Hub."""
    return await asyncio.to_thread(
        hf_hub_download,
        repo_id="tensorblock/Llama3.1-8B-PRM-Mistral-Data-GGUF",
        filename="Llama3.1-8B-PRM-Mistral-Data-Q4_K_M.gguf"
    )


# Global generator instance, initialized during application startup.
generator = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifecycle management for the FastAPI application."""
    global generator
    try:
        prm_model_path = await get_prm_model_path()
        generator = LlamaGenerator(
            llama_model_name="meta-llama/Llama-3.2-1B-Instruct",
            prm_model_path=prm_model_path,
            default_generation_config=GenerationConfig(
                max_new_tokens=100,
                temperature=0.7
            )
        )
        yield
    finally:
        if generator:
            await asyncio.to_thread(generator.cleanup)

app = FastAPI(
    title="Inference Deluxe Service",
    description="""
A service for generating text using LLaMA models with various generation strategies.

Generation Strategies:
- default: Standard autoregressive generation
- majority_voting: Generates multiple responses and selects the most common one
- best_of_n: Generates multiple responses and selects the best based on a scoring metric
- beam_search: Uses beam search for more coherent text generation
- dvts: Diverse verifier tree search guided by the process reward model
    """,
    version="1.0.0",
    lifespan=lifespan
)

# Note: allow_origins=["*"] together with allow_credentials=True is rejected by
# browsers; restrict the origins list before deploying to production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

async def get_generator():
    """Dependency to get the generator instance."""
    if not generator:
        raise HTTPException(
            status_code=503,
            detail="Generator not initialized"
        )
    return generator

@app.post(
    "/generate",
    response_model=GenerationResponse,
    tags=["generation"],
    summary="Generate text response",
    response_description="Generated text with unique identifier"
)
async def generate(
    request: GenerationRequest,
    generator: Any = Depends(get_generator)
):
    """
    Generate a text response based on the provided context and chat history.

    The last message in `messages` is treated as the current user input;
    everything before it is passed as chat history.
    """
    try:
        chat_history = [(msg.role, msg.content) for msg in request.messages[:-1]]
        user_input = request.messages[-1].content

        config = request.config or GenerationConfig()
        model_kwargs = {
            "temperature": config.temperature,
            "max_new_tokens": config.max_new_tokens,
        }

        # Generation is blocking, so run it in a worker thread to keep the
        # event loop responsive.
        response = await asyncio.to_thread(
            generator.generate_with_context,
            context=request.context or "",
            user_input=user_input,
            chat_history=chat_history,
            model_kwargs=model_kwargs,
            max_history_turns=config.max_history_turns,
            strategy=config.strategy,
            num_samples=config.num_samples,
            depth=config.depth,
            breadth=config.breadth,
        )

        return GenerationResponse(
            id=str(uuid.uuid4()),
            content=response
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

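# Example client call for /generate (a sketch, assuming the service runs on
# http://localhost:8000 and the `requests` package is installed):
#
#   import requests
#
#   payload = {
#       "context": "You are a helpful assistant",
#       "messages": [{"role": "user", "content": "What is the capital of France?"}],
#       "config": {"temperature": 0.7, "max_new_tokens": 100, "strategy": "default"},
#       "stream": False,
#   }
#   resp = requests.post("http://localhost:8000/generate", json=payload, timeout=120)
#   resp.raise_for_status()
#   print(resp.json()["content"])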
@app.websocket("/generate/stream")
async def generate_stream(
    websocket: WebSocket,
    generator: Any = Depends(get_generator)
):
    """
    Stream generated text tokens over a WebSocket connection.

    The stream sends JSON messages with the following format:
    - During generation: {"token": "generated_token", "finished": false}
    - End of generation: {"token": "", "finished": true}
    - Error: {"error": "error_message"}
    """
    await websocket.accept()

    try:
        while True:
            request_data = await websocket.receive_text()
            request = GenerationRequest.model_validate_json(request_data)

            chat_history = [(msg.role, msg.content) for msg in request.messages[:-1]]
            user_input = request.messages[-1].content

            config = request.config or GenerationConfig()

            async for token in generator.generate_stream(
                prompt=generator.prompt_builder.format(
                    context=request.context or "",
                    user_input=user_input,
                    chat_history=chat_history
                ),
                config=config
            ):
                await websocket.send_text(json.dumps({
                    "token": token,
                    "finished": False
                }))

            await websocket.send_text(json.dumps({
                "token": "",
                "finished": True
            }))

    except WebSocketDisconnect:
        # Client disconnected; there is nobody left to notify.
        pass
    except Exception as e:
        await websocket.send_text(json.dumps({
            "error": str(e)
        }))
        await websocket.close()

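# Example streaming client (a sketch, assuming the `websockets` package and a
# server on localhost:8000; it follows the JSON protocol documented above):
#
#   import asyncio, json, websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:8000/generate/stream") as ws:
#           await ws.send(json.dumps({
#               "messages": [{"role": "user", "content": "Tell me a joke"}],
#               "stream": True,
#           }))
#           while True:
#               msg = json.loads(await ws.recv())
#               if msg.get("finished") or "error" in msg:
#                   break
#               print(msg["token"], end="", flush=True)
#
#   asyncio.run(main())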
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)