import asyncio
import json
import logging
from pathlib import Path
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional

import httpx
from litserve import LitAPI
from pydantic import BaseModel

from .utils import extract_json


class GenerationResponse(BaseModel):
    generated_text: str


class InferenceApi(LitAPI):
    def __init__(self, config: Dict[str, Any]):
        """Initialize the Inference API with configuration."""
        super().__init__()
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing Inference API")
        self._device = None
        self.stream = False
        self.config = config
        self.llm_config = config.get('llm_server', {})

    def setup(self, device: Optional[str] = None):
        """Synchronous setup method required by LitAPI."""
        self._device = device
        self.logger.info(f"Inference API setup completed on device: {device}")
        return self  # Returning self allows optional call chaining

    async def _get_client(self):
        """Get or create HTTP client as needed."""
        host = self.llm_config.get('host', 'localhost')
        port = self.llm_config.get('port', 8002)
        # Construct base URL, omitting port for HF Spaces
        if 'hf.space' in host:
            base_url = f"https://{host}"
        else:
            base_url = f"http://{host}:{port}"
        return httpx.AsyncClient(
            base_url=base_url,
            timeout=float(self.llm_config.get('timeout', 60.0))
        )

    def _get_endpoint(self, endpoint_name: str) -> str:
        """Get full endpoint path including prefix."""
        endpoints = self.llm_config.get('endpoints', {})
        api_prefix = self.llm_config.get('api_prefix', '')
        endpoint = endpoints.get(endpoint_name, '')
        return f"{api_prefix}{endpoint}"

    async def _make_request(
        self,
        method: str,
        endpoint: str,
        *,
        params: Optional[Dict[str, Any]] = None,
        json: Optional[Dict[str, Any]] = None,
        stream: bool = False
    ) -> Any:
        """Make a request to the LLM Server."""
        endpoint_path = self._get_endpoint(endpoint)
        try:
            # The client already carries the base URL, so only the endpoint
            # path is passed to each request.
            client = await self._get_client()
            self.logger.info(f"Making {method} request to: {client.base_url}{endpoint_path}")
            if stream:
                # For streaming, return both the client and the response context
                # manager; the caller is responsible for closing them.
                return client, client.stream(
                    method,
                    endpoint_path,
                    params=params,
                    json=json
                )
            else:
                # For non-streaming requests, the client is closed here.
                async with client as c:
                    response = await c.request(
                        method,
                        endpoint_path,
                        params=params,
                        json=json
                    )
                    response.raise_for_status()
                    return response
        except Exception as e:
            self.logger.error(f"Error in {method} request to {endpoint_path}: {str(e)}")
            raise

    def predict(self, x: str, **kwargs) -> Iterator[str]:
        """Non-async prediction method that yields results."""
        # Bridge the async generator to a synchronous iterator. Reuse the
        # thread's event loop if one exists; otherwise create one.
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        async def async_gen():
            async for item in self._async_predict(x, **kwargs):
                yield item

        gen = async_gen()
        while True:
            try:
                yield loop.run_until_complete(gen.__anext__())
            except StopAsyncIteration:
                break

    async def _async_predict(self, x: str, **kwargs) -> AsyncIterator[str]:
        """Internal async prediction method."""
        if self.stream:
            async for chunk in self.generate_stream(x, **kwargs):
                yield chunk
        else:
            response = await self.generate_response(x, **kwargs)
            yield response

    async def generate_response(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> str:
        """Generate a complete response by forwarding the request to the LLM Server."""
        self.logger.debug(f"Forwarding generation request for prompt: {prompt[:50]}...")
        try:
            response = await self._make_request(
                "POST",
                "generate",
                json={
                    "prompt": prompt,
                    "system_message": system_message,
                    "max_new_tokens": max_new_tokens
                }
            )
            data = response.json()
            return data["generated_text"]
        except Exception as e:
            self.logger.error(f"Error in generate_response: {str(e)}")
            raise

    async def structured_llm_query(
        self,
        template_name: str,
        input_text: str,
        additional_context: Optional[Dict[str, Any]] = None,
        pre_hooks: Optional[List[Callable]] = None,
        post_hooks: Optional[List[Callable]] = None
    ) -> Dict[str, Any]:
        """Execute a structured LLM query using a template."""
        template_path = Path(__file__).parent / "prompt_templates" / f"{template_name}.json"
        try:
            # Load and parse the template
            with open(template_path) as f:
                template = json.load(f)

            # Apply pre-processing hooks
            processed_input = input_text
            if pre_hooks:
                for hook in pre_hooks:
                    processed_input = hook(processed_input)

            # Format the prompt with the context
            context = {"input_text": processed_input}
            if additional_context:
                context.update(additional_context)
            prompt = template["prompt_template"].format(**context)

            # Make the request to the LLM
            response = await self._make_request(
                "POST",
                "generate",
                json={
                    "prompt": prompt,
                    "system_message": template.get("system_message"),
                    "max_new_tokens": 1000
                }
            )

            # Extract JSON from the response
            data = response.json()
            result = extract_json(data["generated_text"])

            # Apply any additional post-processing hooks
            if post_hooks:
                for hook in post_hooks:
                    result = hook(result)

            return result
        except FileNotFoundError:
            raise ValueError(f"Template {template_name} not found")
        except Exception as e:
            self.logger.error(f"Error in structured_llm_query: {str(e)}")
            raise
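
    # Hypothetical example of a template file read by structured_llm_query
    # (prompt_templates/query_expansion.json). Only the "prompt_template" and
    # "system_message" keys are consumed above; the wording is illustrative:
    #
    #     {
    #         "system_message": "You expand search queries for retrieval.",
    #         "prompt_template": "Expand the following query: {input_text}"
    #     }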

    async def expand_query(
        self,
        query: str,
        system_message: Optional[str] = None
    ) -> Dict[str, Any]:
        """Expand a query for RAG processing."""
        return await self.structured_llm_query(
            template_name="query_expansion",
            input_text=query,
            additional_context={"system_message": system_message} if system_message else None
        )

    async def rerank_chunks(
        self,
        query: str,
        chunks: List[str],
        system_message: Optional[str] = None
    ) -> Dict[str, Any]:
        """Rerank text chunks based on their relevance to the query."""
        # Format chunks as a numbered list for better LLM processing
        formatted_chunks = "\n".join(f"{i+1}. {chunk}" for i, chunk in enumerate(chunks))
        return await self.structured_llm_query(
            template_name="chunk_rerank",
            input_text=query,
            additional_context={
                "chunks": formatted_chunks,
                "system_message": system_message
            }
        )

    async def generate_stream(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> AsyncIterator[str]:
        """Generate a streaming response by forwarding the request to the LLM Server."""
        self.logger.debug(f"Forwarding streaming request for prompt: {prompt[:50]}...")
        try:
            client, stream_cm = await self._make_request(
                "POST",
                "generate_stream",
                json={
                    "prompt": prompt,
                    "system_message": system_message,
                    "max_new_tokens": max_new_tokens
                },
                stream=True
            )
            # _make_request leaves the client open for streaming, so close
            # both the client and the stream here once fully consumed.
            async with client:
                async with stream_cm as response:
                    async for chunk in response.aiter_text():
                        yield chunk
        except Exception as e:
            self.logger.error(f"Error in generate_stream: {str(e)}")
            raise
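
    # Minimal consumption sketch for the streaming path (assumes a running
    # event loop; "api" is an initialized InferenceApi instance):
    #
    #     async for chunk in api.generate_stream("Hello"):
    #         print(chunk, end="", flush=True)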

    async def generate_embedding(self, text: str) -> List[float]:
        """Generate embedding vector from input text."""
        self.logger.debug(f"Forwarding embedding request for text: {text[:50]}...")
        try:
            response = await self._make_request(
                "POST",
                "embedding",
                json={"text": text}
            )
            data = response.json()
            return data["embedding"]
        except Exception as e:
            self.logger.error(f"Error in generate_embedding: {str(e)}")
            raise

    async def check_system_status(self) -> Dict[str, Any]:
        """Check system status of the LLM Server."""
        self.logger.debug("Checking system status...")
        try:
            response = await self._make_request(
                "GET",
                "system_status"
            )
            return response.json()
        except Exception as e:
            self.logger.error(f"Error in check_system_status: {str(e)}")
            raise

    async def download_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
        """Download model files from the LLM Server."""
        self.logger.debug(f"Forwarding model download request for: {model_name or 'default model'}")
        try:
            response = await self._make_request(
                "POST",
                "model_download",
                params={"model_name": model_name} if model_name else None
            )
            return response.json()
        except Exception as e:
            self.logger.error(f"Error in download_model: {str(e)}")
            raise

    async def validate_system(self) -> Dict[str, Any]:
        """Validate system configuration and setup."""
        self.logger.debug("Validating system configuration...")
        try:
            response = await self._make_request(
                "GET",
                "system_validate"
            )
            return response.json()
        except Exception as e:
            self.logger.error(f"Error in validate_system: {str(e)}")
            raise

    async def initialize_model(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Initialize specified model or default model."""
        self.logger.debug(f"Initializing model: {model_name or 'default'}")
        try:
            response = await self._make_request(
                "POST",
                "model_initialize",
                params={"model_name": model_name} if model_name else None
            )
            return response.json()
        except Exception as e:
            self.logger.error(f"Error in initialize_model: {str(e)}")
            raise

    async def initialize_embedding_model(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Initialize embedding model."""
        self.logger.debug(f"Initializing embedding model: {model_name or 'default'}")
        try:
            response = await self._make_request(
                "POST",
                "model_initialize_embedding",
                json={"model_name": model_name} if model_name else {}
            )
            return response.json()
        except Exception as e:
            self.logger.error(f"Error in initialize_embedding_model: {str(e)}")
            raise

    def decode_request(self, request: Any, **kwargs) -> str:
        """Convert the request payload to input format."""
        if isinstance(request, dict) and "prompt" in request:
            return request["prompt"]
        return request

    def encode_response(self, output: Iterator[str], **kwargs) -> Dict[str, Any]:
        """Convert the model output to a response payload."""
        if self.stream:
            # In streaming mode, pass the iterator through for the server to consume.
            return {"generated_text": output}
        try:
            result = next(output)
            return {"generated_text": result}
        except StopIteration:
            return {"generated_text": ""}

    async def cleanup(self):
        """Cleanup method - no longer needed as clients are created per-request."""
        pass
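

# Minimal usage sketch (illustrative, not part of the module's API). It assumes
# a reachable LLM Server and that this module is imported from its package,
# here called "inference_api" (hypothetical name):
#
#     import asyncio
#     from inference_api import InferenceApi
#
#     config = {"llm_server": {"host": "localhost", "port": 8002}}
#     api = InferenceApi(config).setup()
#     text = asyncio.run(api.generate_response("Hello, world"))
#     print(text)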