from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, Union
import torch
import logging
from pathlib import Path
from litgpt.api import LLM
import os
import uvicorn

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="LLM Engine Service")

# Global variable to store the LLM instance
llm_instance = None
class InitializeRequest(BaseModel):
    """
    Configuration for model initialization, including the model path.
    """
    mode: str = "cpu"
    precision: Optional[str] = None
    quantize: Optional[str] = None
    gpu_count: Union[str, int] = "auto"
    model_path: str


class GenerateRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 50
    temperature: float = 1.0
    top_k: Optional[int] = None
    top_p: float = 1.0
    return_as_token_ids: bool = False
    stream: bool = False
@app.post("/initialize")
async def initialize_model(request: InitializeRequest):
    """
    Initialize the LLM model with the specified configuration.
    """
    global llm_instance
    try:
        # Get the project root directory (where main.py is located)
        project_root = Path(__file__).parent
        checkpoints_dir = project_root / "checkpoints"

        # For LitGPT-downloaded models, the path includes the organization
        if "/" in request.model_path:
            # e.g., "mistralai/Mistral-7B-Instruct-v0.3"
            org, model_name = request.model_path.split("/")
            model_path = str(checkpoints_dir / org / model_name)
        else:
            # Fallback for direct model paths
            model_path = str(checkpoints_dir / request.model_path)

        logger.info(f"Using model path: {model_path}")

        # Load the model; let LitGPT distribute automatically unless manual
        # precision/quantization settings were requested
        llm_instance = LLM.load(
            model=model_path,
            distribute=None if request.precision or request.quantize else "auto"
        )

        # If manual distribution is needed
        if request.precision or request.quantize:
            llm_instance.distribute(
                accelerator="cuda" if request.mode == "gpu" else "cpu",
                devices=request.gpu_count,
                precision=request.precision,
                quantize=request.quantize
            )

        logger.info(
            f"Model initialized successfully with config:\n"
            f"Mode: {request.mode}\n"
            f"Precision: {request.precision}\n"
            f"Quantize: {request.quantize}\n"
            f"GPU Count: {request.gpu_count}\n"
            f"Model Path: {model_path}"
        )
        if torch.cuda.is_available():
            logger.info(
                f"Current GPU Memory: {torch.cuda.memory_allocated() / 1024**3:.2f}GB allocated, "
                f"{torch.cuda.memory_reserved() / 1024**3:.2f}GB reserved"
            )

        return {"success": True, "message": "Model initialized successfully"}
    except Exception as e:
        logger.error(f"Error initializing model: {str(e)}")
        # Log detailed GPU memory statistics on failure (only meaningful when CUDA is available)
        if torch.cuda.is_available():
            logger.error(
                f"GPU Memory Stats:\n"
                f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f}GB\n"
                f"Reserved: {torch.cuda.memory_reserved() / 1024**3:.2f}GB\n"
                f"Max Allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f}GB"
            )
        raise HTTPException(status_code=500, detail=f"Error initializing model: {str(e)}")
@app.post("/generate")
async def generate(request: GenerateRequest):
    """
    Generate text using the initialized model.
    """
    global llm_instance
    if llm_instance is None:
        raise HTTPException(status_code=400, detail="Model not initialized. Call /initialize first.")
    try:
        if request.stream:
            # Streaming responses would require FastAPI's StreamingResponse;
            # reject streaming requests until that is implemented.
            raise HTTPException(
                status_code=400,
                detail="Streaming is not currently supported through the API"
            )

        generated_text = llm_instance.generate(
            prompt=request.prompt,
            max_new_tokens=request.max_new_tokens,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            return_as_token_ids=request.return_as_token_ids,
            stream=False  # Force stream to False for now
        )

        response = {
            "generated_text": generated_text.tolist() if request.return_as_token_ids else generated_text,
            "metadata": {
                "prompt": request.prompt,
                "max_new_tokens": request.max_new_tokens,
                "temperature": request.temperature,
                "top_k": request.top_k,
                "top_p": request.top_p
            }
        }
        return response
    except HTTPException:
        # Let deliberate HTTP errors (e.g., the streaming rejection above) pass through unchanged
        raise
    except Exception as e:
        logger.error(f"Error generating text: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
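
# A possible streaming variant (sketch only, not wired into the app). It assumes
# that LLM.generate(..., stream=True) yields text chunks incrementally, which is
# an assumption about the LitGPT API rather than something this service verifies.
#
#   from fastapi.responses import StreamingResponse
#
#   @app.post("/generate_stream")
#   async def generate_stream(request: GenerateRequest):
#       if llm_instance is None:
#           raise HTTPException(status_code=400, detail="Model not initialized. Call /initialize first.")
#
#       def token_iterator():
#           # Assumes stream=True returns an iterator of text fragments
#           for chunk in llm_instance.generate(
#               prompt=request.prompt,
#               max_new_tokens=request.max_new_tokens,
#               temperature=request.temperature,
#               stream=True,
#           ):
#               yield chunk
#
#       return StreamingResponse(token_iterator(), media_type="text/plain")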
@app.get("/health")
async def health_check():
    """
    Check if the service is running and the model is loaded.
    """
    global llm_instance
    status = {
        "status": "healthy",
        "model_loaded": llm_instance is not None,
    }
    if llm_instance is not None:
        status["model_info"] = {
            "model_name": llm_instance.config.name,
            "device": str(next(llm_instance.model.parameters()).device)
        }
    return status
def main():
    # Load environment variables or configuration here
    host = os.getenv("LLM_ENGINE_HOST", "0.0.0.0")
    port = int(os.getenv("LLM_ENGINE_PORT", "8001"))

    # Start the server
    uvicorn.run(
        app,
        host=host,
        port=port,
        log_level="info",
        reload=False
    )


if __name__ == "__main__":
    main()
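
# Hypothetical client sketch: the snippet below is not part of the service itself;
# it only illustrates how the /initialize and /generate routes registered above
# could be exercised from another process. It assumes the server is reachable at
# http://localhost:8001 and that the `requests` package is installed; the model
# path reuses the "mistralai/Mistral-7B-Instruct-v0.3" example mentioned in the
# comments above.
#
#   import requests
#
#   base_url = "http://localhost:8001"
#
#   # Initialize the model once before generating
#   requests.post(f"{base_url}/initialize", json={
#       "mode": "cpu",
#       "model_path": "mistralai/Mistral-7B-Instruct-v0.3",
#   })
#
#   # Generate a short completion
#   resp = requests.post(f"{base_url}/generate", json={
#       "prompt": "Hello, world!",
#       "max_new_tokens": 32,
#       "temperature": 0.7,
#   })
#   print(resp.json()["generated_text"])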