from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from typing import Optional
import json
import logging
import uuid
from time import time

from .api import InferenceApi
from .schemas import (
    GenerateRequest,
    EmbeddingRequest,
    EmbeddingResponse,
    SystemStatusResponse,
    ValidationResponse,
    ChatCompletionRequest,
    ChatCompletionResponse,
    QueryExpansionRequest,
    QueryExpansionResponse,
    ChunkRerankRequest,
    ChunkRerankResponse,
)

router = APIRouter()
logger = logging.getLogger(__name__)

# Module-level state, populated once by init_router() before the router is mounted.
api: Optional[InferenceApi] = None
config = None


def init_router(inference_api: InferenceApi, conf):
    """Initialize the router with an already set up API instance."""
    global api, config
    api = inference_api
    config = conf
    logger.info("Router initialized with Inference API instance")

# NOTE: the route decorators in this module were missing from the extracted
# source; the paths below are reasonable assumptions, not confirmed originals.
@router.post("/generate")
async def generate_text(request: GenerateRequest):
    """Generate a text response from a prompt."""
    logger.info(f"Received generation request for prompt: {request.prompt[:50]}...")
    try:
        response = await api.generate_response(
            prompt=request.prompt,
            system_message=request.system_message,
            max_new_tokens=request.max_new_tokens,
        )
        logger.info("Successfully generated response")
        return {"generated_text": response}
    except Exception as e:
        logger.error(f"Error in generate_text endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@router.post("/generate/stream")
async def generate_stream(request: GenerateRequest):
    """Generate a streaming text response from a prompt."""
    logger.info(f"Received streaming generation request for prompt: {request.prompt[:50]}...")
    try:
        return StreamingResponse(
            api.generate_stream(
                prompt=request.prompt,
                system_message=request.system_message,
                max_new_tokens=request.max_new_tokens,
            ),
            media_type="text/event-stream",
        )
    except Exception as e:
        logger.error(f"Error in generate_stream endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
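
# Example call for the streaming endpoint (path as assumed above; adjust to
# the real route and host if they differ):
#
#   curl -N -X POST http://localhost:8000/generate/stream \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello", "max_new_tokens": 64}'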

@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """OpenAI-compatible chat completion endpoint."""
    logger.info(f"Received chat completion request with {len(request.messages)} messages")
    try:
        # Use only the last message as the prompt; earlier turns are not
        # forwarded to the model here.
        last_message = request.messages[-1].content
        if request.stream:
            # One completion id shared by every chunk of this response.
            completion_id = f"chatcmpl-{uuid.uuid4().hex}"

            # For streaming, wrap the raw token stream in OpenAI-style SSE chunks.
            # (Named to avoid shadowing the module-level generate_stream endpoint.)
            async def stream_chunks():
                async for chunk in api.generate_stream(prompt=last_message):
                    response_chunk = {
                        "id": completion_id,
                        "object": "chat.completion.chunk",
                        "created": int(time()),
                        "model": request.model,
                        "choices": [{
                            "index": 0,
                            "delta": {"content": chunk},
                            "finish_reason": None,
                        }],
                    }
                    yield f"data: {json.dumps(response_chunk)}\n\n"
                # Signal end of stream, as the OpenAI API does.
                yield "data: [DONE]\n\n"

            return StreamingResponse(
                stream_chunks(),
                media_type="text/event-stream",
            )
        else:
            # For non-streaming, generate the full response in one call.
            response_text = await api.generate_response(prompt=last_message)
            # Convert to OpenAI response format.
            return ChatCompletionResponse.from_response(
                content=response_text,
                model=request.model,
            )
    except Exception as e:
        logger.error(f"Error in chat completion endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
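
# Each streamed chunk is an OpenAI-style SSE event; illustrative output:
#
#   data: {"id": "chatcmpl-...", "object": "chat.completion.chunk",
#          "created": 1700000000, "model": "...",
#          "choices": [{"index": 0, "delta": {"content": "Hel"}, "finish_reason": null}]}
#   data: [DONE]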

@router.post("/expand_query", response_model=QueryExpansionResponse)
async def expand_query(request: QueryExpansionRequest):
    """Expand a query for RAG processing."""
    logger.info(f"Received query expansion request: {request.query[:50]}...")
    try:
        result = await api.expand_query(
            query=request.query,
            system_message=request.system_message,
        )
        logger.info("Successfully expanded query")
        return result
    except FileNotFoundError as e:
        logger.error(f"Template file not found: {str(e)}")
        raise HTTPException(status_code=500, detail="Query expansion template not found")
    except json.JSONDecodeError as e:
        logger.error(f"Invalid JSON response from LLM: {str(e)}")
        raise HTTPException(status_code=500, detail="Invalid response format from LLM")
    except Exception as e:
        logger.error(f"Error in expand_query endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@router.post("/rerank", response_model=ChunkRerankResponse)
async def rerank_chunks(request: ChunkRerankRequest):
    """Rerank chunks by their relevance to the query."""
    logger.info(f"Received reranking request for query: {request.query[:50]}...")
    try:
        result = await api.rerank_chunks(
            query=request.query,
            chunks=request.chunks,
            system_message=request.system_message,
        )
        logger.info(f"Successfully reranked {len(request.chunks)} chunks")
        return result
    except Exception as e:
        logger.error(f"Error in rerank_chunks endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@router.post("/embedding", response_model=EmbeddingResponse)
async def generate_embedding(request: EmbeddingRequest):
    """Generate an embedding vector from text."""
    logger.info(f"Received embedding request for text: {request.text[:50]}...")
    try:
        embedding = await api.generate_embedding(request.text)
        logger.info(f"Successfully generated embedding of dimension {len(embedding)}")
        return EmbeddingResponse(
            embedding=embedding,
            dimension=len(embedding),
        )
    except Exception as e:
        logger.error(f"Error in generate_embedding endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
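
# Illustrative request/response shapes (field names follow the imported
# schemas; the dimension value is an example, not a guarantee):
#   request:  {"text": "hello world"}
#   response: {"embedding": [0.12, -0.07, ...], "dimension": 384}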

@router.get("/system/status", response_model=SystemStatusResponse)
async def check_system():
    """Get system status from the LLM server."""
    try:
        return await api.check_system_status()
    except Exception as e:
        logger.error(f"Error checking system status: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/system/validate", response_model=ValidationResponse)
async def validate_system():
    """Get system validation status from the LLM server."""
    try:
        return await api.validate_system()
    except Exception as e:
        logger.error(f"Error validating system: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@router.post("/model/initialize")
async def initialize_model(model_name: Optional[str] = None):
    """Initialize a model for use."""
    try:
        return await api.initialize_model(model_name)
    except Exception as e:
        logger.error(f"Error initializing model: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/model/initialize/embedding")
async def initialize_embedding_model(model_name: Optional[str] = None):
    """Initialize a model specifically for embeddings."""
    try:
        return await api.initialize_embedding_model(model_name)
    except Exception as e:
        logger.error(f"Error initializing embedding model: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/model/download")
async def download_model(model_name: Optional[str] = None):
    """Download model files to local storage."""
    try:
        # Fall back to the model name from config if none is provided.
        model_to_download = model_name or config["model"]["defaults"]["model_name"]
        logger.info(f"Received request to download model: {model_to_download}")
        result = await api.download_model(model_to_download)
        logger.info(f"Successfully downloaded model: {model_to_download}")
        return result
    except Exception as e:
        logger.error(f"Error downloading model: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

async def shutdown_event():
    """Clean up resources on shutdown."""
    if api:
        await api.cleanup()
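
# shutdown_event is not a route; the application is expected to register it,
# e.g. (a sketch, assuming a FastAPI app object named `app`):
#
#   app.add_event_handler("shutdown", routes.shutdown_event)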