# app.py — uploaded by oki692 via huggingface_hub (commit ca418d0, verified)
import asyncio
import json
import os
import secrets

import httpx
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
app = FastAPI(title="Ollama Streaming API", version="1.0.0")
# CORS middleware for browser access
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Hardcoded configuration
MODEL_NAME = "deepseek-r1:1.5b"
CONNECT_KEY = "manus-ollama-2024"
OLLAMA_BASE_URL = "http://localhost:11434"
class ChatRequest(BaseModel):
prompt: str
key: str
class HealthResponse(BaseModel):
status: str
model: str
endpoint: str
# Middleware to disable all caching
@app.middleware("http")
async def disable_cache_middleware(request, call_next):
response = await call_next(request)
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0"
response.headers["Pragma"] = "no-cache"
response.headers["Expires"] = "0"
return response
@app.get("/", response_model=HealthResponse)
async def root():
"""Health check endpoint"""
space_url = os.getenv("SPACE_URL", "http://localhost:7860")
return HealthResponse(
status="online",
model=MODEL_NAME,
endpoint=space_url
)
@app.get("/health")
async def health():
"""Detailed health check"""
try:
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.get(f"{OLLAMA_BASE_URL}/api/tags")
if response.status_code == 200:
return {"status": "healthy", "ollama": "connected", "model": MODEL_NAME}
except Exception as e:
return {"status": "degraded", "ollama": "disconnected", "error": str(e)}
async def generate_stream(prompt: str):
"""Generate streaming response from Ollama without caching"""
try:
async with httpx.AsyncClient(timeout=300.0) as client:
payload = {
"model": MODEL_NAME,
"prompt": prompt,
"stream": True,
"options": {
"temperature": 0.7,
"num_predict": 2048,
"top_k": 40,
"top_p": 0.9,
"num_ctx": 2048,
"num_batch": 512,
"num_gpu": 1,
"num_thread": 4,
}
}
async with client.stream(
"POST",
f"{OLLAMA_BASE_URL}/api/generate",
json=payload,
timeout=300.0
) as response:
if response.status_code != 200:
yield f"data: {json.dumps({'error': 'Ollama API error'})}\n\n"
return
async for line in response.aiter_lines():
if line.strip():
try:
data = json.loads(line)
if "response" in data:
yield f"data: {json.dumps({'text': data['response'], 'done': data.get('done', False)})}\n\n"
if data.get("done", False):
break
except json.JSONDecodeError:
continue
except Exception as e:
yield f"data: {json.dumps({'error': str(e)})}\n\n"
@app.post("/stream")
async def stream_chat(request: ChatRequest):
"""Stream chat completions with key authentication - NO CACHING"""
if request.key != CONNECT_KEY:
raise HTTPException(status_code=403, detail="Invalid connect key")
if not request.prompt or len(request.prompt.strip()) == 0:
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
return StreamingResponse(
generate_stream(request.prompt),
media_type="text/event-stream",
headers={
"Cache-Control": "no-store, no-cache, must-revalidate, max-age=0, private",
"Pragma": "no-cache",
"Expires": "0",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"X-Content-Type-Options": "nosniff"
}
)
@app.get("/models")
async def list_models():
"""List available models"""
return {"models": [MODEL_NAME], "default": MODEL_NAME}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")