Spaces:
Runtime error
Runtime error
import logging
from typing import Optional

from fastapi import APIRouter, HTTPException

from .api import InferenceApi
from .schemas import (
    GenerateRequest,
    EmbeddingRequest,
    EmbeddingResponse,
    SystemStatusResponse,
    ValidationResponse,
)

# Router and logger shared by every handler in this module.
# NOTE(review): no @router.* decorators are visible in this chunk, so the
# handlers below are not actually registered on `router` as shown — confirm
# against the original file whether the decorators were lost in extraction.
router = APIRouter()
logger = logging.getLogger(__name__)

# Module-global InferenceApi instance; populated by init_router() at startup
# and torn down by shutdown_event().
api = None
def init_router(config: dict):
    """Create the module-global InferenceApi instance from *config*.

    Must be called once at application startup, before any of the
    handlers in this module run — they all read the global ``api``.
    """
    global api
    api = InferenceApi(config)
    logger.info("Router initialized with Inference API instance")
async def generate_text(request: GenerateRequest):
    """Generate a text response for the given prompt.

    Args:
        request: GenerateRequest carrying ``prompt``, ``system_message``
            and ``max_new_tokens`` (schema defined in ``.schemas``).

    Returns:
        dict: the generated text under the ``"generated_text"`` key.

    Raises:
        HTTPException: 503 if init_router() was never called; 500 on any
            downstream failure; downstream HTTPExceptions pass through.
    """
    logger.info(f"Received generation request for prompt: {request.prompt[:50]}...")
    # Guard against use before init_router(); previously this surfaced as an
    # opaque AttributeError wrapped in a generic 500.
    if api is None:
        raise HTTPException(status_code=503, detail="Inference API not initialized")
    try:
        response = await api.generate_response(
            prompt=request.prompt,
            system_message=request.system_message,
            max_new_tokens=request.max_new_tokens
        )
        logger.info("Successfully generated response")
        return {"generated_text": response}
    except HTTPException:
        # Preserve deliberate HTTP errors from downstream instead of
        # re-wrapping them as generic 500s.
        raise
    except Exception as e:
        logger.error(f"Error in generate_text endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def generate_stream(request: GenerateRequest):
    """Generate a streaming text response for the given prompt.

    Raises:
        HTTPException: 503 if init_router() was never called; 500 if
            creating the stream fails; downstream HTTPExceptions pass through.
    """
    logger.info(f"Received streaming generation request for prompt: {request.prompt[:50]}...")
    # Guard against use before init_router().
    if api is None:
        raise HTTPException(status_code=503, detail="Inference API not initialized")
    try:
        # Returned without awaiting: errors raised lazily while the stream is
        # consumed will NOT be caught by this handler — only setup failures are.
        return api.generate_stream(
            prompt=request.prompt,
            system_message=request.system_message,
            max_new_tokens=request.max_new_tokens
        )
    except HTTPException:
        # Let deliberate HTTP errors through unchanged.
        raise
    except Exception as e:
        logger.error(f"Error in generate_stream endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def generate_embedding(request: EmbeddingRequest):
    """Generate an embedding vector for the given text.

    Returns:
        EmbeddingResponse: the embedding plus its dimension.

    Raises:
        HTTPException: 503 if init_router() was never called; 500 on any
            downstream failure; downstream HTTPExceptions pass through.
    """
    logger.info(f"Received embedding request for text: {request.text[:50]}...")
    # Guard against use before init_router().
    if api is None:
        raise HTTPException(status_code=503, detail="Inference API not initialized")
    try:
        embedding = await api.generate_embedding(request.text)
        logger.info(f"Successfully generated embedding of dimension {len(embedding)}")
        return EmbeddingResponse(
            embedding=embedding,
            dimension=len(embedding)
        )
    except HTTPException:
        # Let deliberate HTTP errors through unchanged.
        raise
    except Exception as e:
        logger.error(f"Error in generate_embedding endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def check_system():
    """Get system status from the LLM Server.

    Raises:
        HTTPException: 503 if init_router() was never called; 500 on any
            downstream failure; downstream HTTPExceptions pass through.
    """
    # Guard against use before init_router().
    if api is None:
        raise HTTPException(status_code=503, detail="Inference API not initialized")
    try:
        return await api.check_system_status()
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error checking system status: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def validate_system():
    """Get system validation status from the LLM Server.

    Raises:
        HTTPException: 503 if init_router() was never called; 500 on any
            downstream failure; downstream HTTPExceptions pass through.
    """
    # Guard against use before init_router().
    if api is None:
        raise HTTPException(status_code=503, detail="Inference API not initialized")
    try:
        return await api.validate_system()
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error validating system: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def initialize_model(model_name: Optional[str] = None):
    """Initialize a model for use.

    Args:
        model_name: optional model identifier; semantics of None are
            delegated to InferenceApi (presumably a default model — confirm).

    Raises:
        HTTPException: 503 if init_router() was never called; 500 on any
            downstream failure; downstream HTTPExceptions pass through.
    """
    # Guard against use before init_router().
    if api is None:
        raise HTTPException(status_code=503, detail="Inference API not initialized")
    try:
        return await api.initialize_model(model_name)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error initializing model: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def initialize_embedding_model(model_name: Optional[str] = None):
    """Initialize a model specifically for embeddings.

    Args:
        model_name: optional model identifier; semantics of None are
            delegated to InferenceApi (presumably a default model — confirm).

    Raises:
        HTTPException: 503 if init_router() was never called; 500 on any
            downstream failure; downstream HTTPExceptions pass through.
    """
    # Guard against use before init_router().
    if api is None:
        raise HTTPException(status_code=503, detail="Inference API not initialized")
    try:
        return await api.initialize_embedding_model(model_name)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error initializing embedding model: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def shutdown_event():
    """Release the Inference API's resources at application shutdown."""
    if not api:
        # init_router() was never called; nothing to clean up.
        return
    await api.close()