# api.py file in main directory of the Inference API module.
import logging
from typing import Any, AsyncIterator, Dict, Optional

import httpx
from litserve import LitAPI
from pydantic import BaseModel


class GenerationResponse(BaseModel):
    generated_text: str
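

# LitServe drives the request lifecycle as decode_request -> predict ->
# encode_response; the payloads yielded by encode_response match the
# GenerationResponse schema above, e.g. {"generated_text": "Hello"}.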
class InferenceApi(LitAPI):
    def __init__(self):
        """Initialize the Inference API with configuration."""
        super().__init__()
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing Inference API")
        self.client = None

    async def setup(self, device: Optional[str] = None):
        """Setup method required by LitAPI - initialize the HTTP client."""
        self._device = device
        self.client = httpx.AsyncClient(
            base_url="http://localhost:8002",  # We'll need to make this configurable
            timeout=60.0
        )
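        # A sketch of one way to do that (the env var name LLM_SERVER_URL is an
        # assumption, not part of this codebase):
        #   base_url=os.getenv("LLM_SERVER_URL", "http://localhost:8002")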
| self.logger.info(f"Inference API setup completed on device: {device}") | |

    async def predict(self, x: str, **kwargs) -> AsyncIterator[str]:
        """
        Main prediction method required by LitAPI.

        Always yields: chunks in streaming mode, or the complete response
        in non-streaming mode. `self.stream` is set by the LitServe server.
        """
        if self.stream:
            async for chunk in self.generate_stream(x, **kwargs):
                yield chunk
        else:
            response = await self.generate_response(x, **kwargs)
            yield response

    def decode_request(self, request: Any, **kwargs) -> str:
        """Convert the request payload to the input format: accept either a
        raw string or a JSON body with a "prompt" field."""
        if isinstance(request, dict) and "prompt" in request:
            return request["prompt"]
        return request

    def encode_response(self, output: AsyncIterator[str], **kwargs) -> AsyncIterator[Dict[str, str]]:
        """Convert the model output to a response payload."""
        async def wrapper():
            async for chunk in output:
                yield {"generated_text": chunk}
        return wrapper()

    async def generate_response(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> str:
        """Generate a complete response by forwarding the request to the LLM Server."""
        self.logger.debug(f"Forwarding generation request for prompt: {prompt[:50]}...")
        try:
            response = await self.client.post(
                "/api/v1/generate",
                json={
                    "prompt": prompt,
                    "system_message": system_message,
                    "max_new_tokens": max_new_tokens
                }
            )
            response.raise_for_status()
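            # The LLM Server is assumed to return JSON shaped like
            # {"generated_text": "..."}; only that field is used here.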
            data = response.json()
            return data["generated_text"]
        except Exception as e:
            self.logger.error(f"Error in generate_response: {str(e)}")
            raise

    async def generate_stream(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> AsyncIterator[str]:
        """Generate a streaming response by forwarding the request to the LLM Server."""
        self.logger.debug(f"Forwarding streaming request for prompt: {prompt[:50]}...")
        try:
            async with self.client.stream(
                "POST",
                "/api/v1/generate/stream",
                json={
                    "prompt": prompt,
                    "system_message": system_message,
                    "max_new_tokens": max_new_tokens
                }
            ) as response:
                response.raise_for_status()
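                # aiter_text() yields decoded text chunks as the server sends
                # them; chunk boundaries are whatever the LLM Server emits.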
                async for chunk in response.aiter_text():
                    yield chunk
        except Exception as e:
            self.logger.error(f"Error in generate_stream: {str(e)}")
            raise

    async def cleanup(self):
        """Cleanup method - close the HTTP client."""
        if self.client:
            await self.client.aclose()
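

# --- Usage sketch (not part of the original file) ---
# A minimal way to serve this API with litserve. stream=True tells LitServer to
# treat predict() as a streaming generator; the port is illustrative, and the
# LLM Server is assumed to already be listening on http://localhost:8002.
if __name__ == "__main__":
    import litserve as ls

    server = ls.LitServer(InferenceApi(), stream=True)
    # Requests then go to litserve's default endpoint, POST /predict, e.g.:
    #   curl -N -X POST http://localhost:8001/predict \
    #        -H "Content-Type: application/json" -d '{"prompt": "Hello"}'
    server.run(port=8001)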