from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import Optional
import httpx
import logging
import time

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI(
    title="Ollama Generate API",
    description="Simple REST API for Ollama text generation",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Ollama server configuration
OLLAMA_BASE_URL = "http://localhost:11434"

# Pydantic models
class GenerateRequest(BaseModel):
    model: str = Field(..., description="Model name to use for generation")
    prompt: str = Field(..., description="Input prompt for text generation")
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: Optional[float] = Field(0.9, ge=0.0, le=1.0, description="Top-p sampling parameter")
    max_tokens: Optional[int] = Field(512, ge=1, le=4096, description="Maximum tokens to generate")

class GenerateResponse(BaseModel):
    model: str
    response: str
    done: bool
    total_duration: Optional[int] = None
    load_duration: Optional[int] = None
    prompt_eval_count: Optional[int] = None
    eval_count: Optional[int] = None
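
# Illustrative request payload accepted by GenerateRequest (values are
# examples only). Note that temperature, top_p, and max_tokens are forwarded
# to Ollama below as options.temperature, options.top_p, and options.num_predict:
#   {"model": "tinyllama", "prompt": "Why is the sky blue?",
#    "temperature": 0.7, "top_p": 0.9, "max_tokens": 128}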

# HTTP client for Ollama API
def get_ollama_client() -> httpx.AsyncClient:
    return httpx.AsyncClient(timeout=300.0)  # 5 minute timeout

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    try:
        async with get_ollama_client() as client:
            response = await client.get(f"{OLLAMA_BASE_URL}/api/version")
            if response.status_code == 200:
                return {
                    "status": "healthy",
                    "ollama_status": "running",
                    "timestamp": time.time()
                }
            else:
                return {
                    "status": "degraded",
                    "ollama_status": "error",
                    "error": f"Ollama returned status {response.status_code}",
                    "timestamp": time.time()
                }
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return {
            "status": "unhealthy",
            "ollama_status": "unreachable",
            "error": str(e),
            "timestamp": time.time()
        }

@app.post("/generate", response_model=GenerateResponse)
async def generate_text(request: GenerateRequest):
    """Generate text completion using Ollama"""
    try:
        generate_data = {
            "model": request.model,
            "prompt": request.prompt,
            "stream": False,  # Always non-streaming for simplicity
            "options": {
                "temperature": request.temperature,
                "top_p": request.top_p,
                "num_predict": request.max_tokens
            }
        }

        logger.info(f"Generating text with model: {request.model}")

        async with get_ollama_client() as client:
            response = await client.post(
                f"{OLLAMA_BASE_URL}/api/generate",
                json=generate_data,
                timeout=300.0
            )

            if response.status_code == 404:
                raise HTTPException(
                    status_code=404,
                    detail=f"Model '{request.model}' not found. Make sure the model is pulled and available."
                )

            response.raise_for_status()
            result = response.json()

            return GenerateResponse(
                model=result.get("model", request.model),
                response=result.get("response", ""),
                done=result.get("done", True),
                total_duration=result.get("total_duration"),
                load_duration=result.get("load_duration"),
                prompt_eval_count=result.get("prompt_eval_count"),
                eval_count=result.get("eval_count")
            )
    except HTTPException:
        # Let the explicit HTTPExceptions above propagate unchanged
        raise
    except httpx.HTTPStatusError as e:
        # Only HTTPStatusError (not the broader HTTPError) carries a .response
        logger.error(f"Generate request failed: Status {e.response.status_code}")
        if e.response.status_code == 404:
            raise HTTPException(
                status_code=404,
                detail=f"Model '{request.model}' not found. Make sure it's installed."
            )
        raise HTTPException(
            status_code=500,
            detail=f"Generation failed: {str(e)}"
        )
    except httpx.TimeoutException:
        logger.error("Generate request timed out")
        raise HTTPException(
            status_code=408,
            detail="Request timed out. Try with a shorter prompt or smaller max_tokens."
        )
    except Exception as e:
        logger.error(f"Unexpected error in generate: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        )

@app.get("/")
async def root():
    """Root endpoint with API information"""
    return {
        "message": "Ollama Generate API",
        "version": "1.0.0",
        "endpoints": {
            "health": "/health - Check if Ollama is running",
            "generate": "/generate - Generate text using Ollama models",
            "docs": "/docs - API documentation"
        },
        "usage": {
            "example": {
                "url": "/generate",
                "method": "POST",
                "body": {
                    "model": "tinyllama",
                    "prompt": "Hello, how are you?",
                    "temperature": 0.7,
                    "max_tokens": 100
                }
            }
        },
        "status": "running"
    }

if __name__ == "__main__":
    import uvicorn
    logger.info("Starting Ollama Generate API server...")
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
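
# Usage sketch, kept as a comment so this file stays a single runnable module.
# It assumes the server above is running on localhost:7860 and that the
# requested model has already been pulled (e.g. `ollama pull tinyllama`):
#
#   import httpx
#   r = httpx.post(
#       "http://localhost:7860/generate",
#       json={"model": "tinyllama", "prompt": "Hello, how are you?",
#             "max_tokens": 100},
#       timeout=300.0,
#   )
#   r.raise_for_status()
#   print(r.json()["response"])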