"""Ollama Generate API: a simple REST API for Ollama text generation, built with FastAPI."""

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import Optional
import httpx
import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Ollama Generate API",
    description="Simple REST API for Ollama text generation",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Default port of a locally running Ollama daemon.
OLLAMA_BASE_URL = "http://localhost:11434"


class GenerateRequest(BaseModel):
    model: str = Field(..., description="Model name to use for generation")
    prompt: str = Field(..., description="Input prompt for text generation")
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: Optional[float] = Field(0.9, ge=0.0, le=1.0, description="Top-p sampling parameter")
    max_tokens: Optional[int] = Field(512, ge=1, le=4096, description="Maximum tokens to generate")


class GenerateResponse(BaseModel):
    model: str
    response: str
    done: bool
    total_duration: Optional[int] = None
    load_duration: Optional[int] = None
    prompt_eval_count: Optional[int] = None
    eval_count: Optional[int] = None
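# Note: the Field bounds above are enforced before the handler runs; for
# example, a body with "temperature": 3.0 (above the le=2.0 bound) is
# rejected by FastAPI with a 422 validation error rather than being
# forwarded to Ollama.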


def get_ollama_client() -> httpx.AsyncClient:
    # A plain factory is enough here; callers open the client with "async with".
    # The generous timeout accommodates slow model loads and long generations.
    return httpx.AsyncClient(timeout=300.0)


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    try:
        async with get_ollama_client() as client:
            response = await client.get(f"{OLLAMA_BASE_URL}/api/version")
            if response.status_code == 200:
                return {
                    "status": "healthy",
                    "ollama_status": "running",
                    "timestamp": time.time()
                }
            else:
                return {
                    "status": "degraded",
                    "ollama_status": "error",
                    "error": f"Ollama returned status {response.status_code}",
                    "timestamp": time.time()
                }
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return {
            "status": "unhealthy",
            "ollama_status": "unreachable",
            "error": str(e),
            "timestamp": time.time()
        }
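# Quick check from the shell (assuming the server runs on port 7860 as
# configured in the __main__ block below):
#
#   curl http://localhost:7860/health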


@app.post("/generate", response_model=GenerateResponse)
async def generate_text(request: GenerateRequest):
    """Generate text completion using Ollama"""
    generate_data = {
        "model": request.model,
        "prompt": request.prompt,
        "stream": False,
        "options": {
            "temperature": request.temperature,
            "top_p": request.top_p,
            "num_predict": request.max_tokens
        }
    }

    logger.info(f"Generating text with model: {request.model}")

    try:
        async with get_ollama_client() as client:
            response = await client.post(
                f"{OLLAMA_BASE_URL}/api/generate",
                json=generate_data
            )

        if response.status_code == 404:
            raise HTTPException(
                status_code=404,
                detail=f"Model '{request.model}' not found. Make sure the model is pulled and available."
            )

        response.raise_for_status()
        result = response.json()

        return GenerateResponse(
            model=result.get("model", request.model),
            response=result.get("response", ""),
            done=result.get("done", True),
            total_duration=result.get("total_duration"),
            load_duration=result.get("load_duration"),
            prompt_eval_count=result.get("prompt_eval_count"),
            eval_count=result.get("eval_count")
        )

    except HTTPException:
        # Re-raise our own 404 above rather than letting the generic
        # handler below turn it into a 500.
        raise
    except httpx.TimeoutException:
        # This clause must precede the httpx.HTTPStatusError one:
        # TimeoutException is a subclass of httpx.HTTPError, so a broader
        # httpx clause listed first would shadow it.
        logger.error("Generate request timed out")
        raise HTTPException(
            status_code=408,
            detail="Request timed out. Try with a shorter prompt or smaller max_tokens."
        )
    except httpx.HTTPStatusError as e:
        # Only HTTPStatusError (raised by raise_for_status) carries a
        # .response attribute; the base httpx.HTTPError does not.
        logger.error(f"Generate request failed: Status {e.response.status_code}")
        raise HTTPException(
            status_code=500,
            detail=f"Generation failed: {str(e)}"
        )
    except Exception as e:
        logger.error(f"Unexpected error in generate: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        )
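# Example call, mirroring the usage sketch served by the root endpoint below
# (the model name is illustrative; use one you have pulled, e.g. via "ollama pull"):
#
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"model": "tinyllama", "prompt": "Hello, how are you?", "temperature": 0.7, "max_tokens": 100}'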


@app.get("/")
async def root():
    """Root endpoint with API information"""
    return {
        "message": "Ollama Generate API",
        "version": "1.0.0",
        "endpoints": {
            "health": "/health - Check if Ollama is running",
            "generate": "/generate - Generate text using Ollama models",
            "docs": "/docs - API documentation"
        },
        "usage": {
            "example": {
                "url": "/generate",
                "method": "POST",
                "body": {
                    "model": "tinyllama",
                    "prompt": "Hello, how are you?",
                    "temperature": 0.7,
                    "max_tokens": 100
                }
            }
        },
        "status": "running"
    }


if __name__ == "__main__":
    import uvicorn

    logger.info("Starting Ollama Generate API server...")
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
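# Alternatively, start the app with the uvicorn CLI (assuming this file is
# saved as main.py; substitute your actual module name):
#
#   uvicorn main:app --host 0.0.0.0 --port 7860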