# ollama-api-lfm / app.py
# Uploaded by oki692 via huggingface_hub (commit 9a18084, verified)
import json
import os

import httpx
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
# ASGI application fronting a local Ollama server, re-exposing its HTTP API.
app = FastAPI(title="Ollama Compatible API Proxy", version="1.0.0")
# CORS middleware
# NOTE(review): allow_credentials=True combined with allow_origins=["*"] is
# rejected by browsers per the CORS spec (a wildcard origin cannot be echoed
# with credentials) -- confirm whether credentialed requests are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # any origin may call this proxy
    allow_credentials=True,
    allow_methods=["*"],  # all HTTP methods
    allow_headers=["*"],  # all request headers
)
# Base URL of the Ollama server being proxied.  Overridable via the
# OLLAMA_BASE_URL environment variable; falls back to the previously
# hard-coded localhost default, so existing deployments are unaffected.
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434")
# Middleware to disable all caching
@app.middleware("http")
async def disable_cache_middleware(request, call_next):
    """Stamp no-cache headers onto every outgoing response."""
    response = await call_next(request)
    no_cache_headers = {
        "Cache-Control": "no-store, no-cache, must-revalidate, max-age=0",
        "Pragma": "no-cache",
        "Expires": "0",
    }
    for name, value in no_cache_headers.items():
        response.headers[name] = value
    return response
@app.get("/")
async def root():
    """Health check"""
    return {
        "status": "online",
        "service": "Ollama Compatible API",
        "model": "deepseek-r1:1.5b",
    }
# Proxy all Ollama API endpoints
@app.api_route("/api/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy_ollama_api(path: str, request: Request):
    """Proxy any /api/* request to the local Ollama server.

    Streams the upstream response chunk-by-chunk when the JSON request
    body contains ``"stream": true``; otherwise buffers the whole
    upstream response and forwards it.
    """
    body = await request.body()

    # Forward the caller's headers, minus Host (httpx sets the right one
    # for the upstream URL).
    headers = dict(request.headers)
    headers.pop("host", None)

    # Peek at the JSON body (if any) for a "stream" flag.
    is_streaming = False
    if body:
        try:
            is_streaming = json.loads(body).get("stream", False)
        except (ValueError, AttributeError):
            # Body is not a JSON object -- treat as non-streaming.
            pass

    upstream_url = f"{OLLAMA_BASE_URL}/api/{path}"

    if is_streaming:
        # BUGFIX: the client must be opened *inside* the generator.  The
        # original opened it in an `async with` surrounding the `return`,
        # which closed the client before the StreamingResponse body was
        # ever consumed, killing the stream.
        async def stream_response():
            async with httpx.AsyncClient(timeout=300.0) as client:
                async with client.stream(
                    request.method,
                    upstream_url,
                    content=body,
                    headers=headers,
                ) as upstream:
                    async for chunk in upstream.aiter_raw():
                        yield chunk

        return StreamingResponse(
            stream_response(),
            media_type="application/x-ndjson",
            headers={
                "Cache-Control": "no-store, no-cache, must-revalidate, max-age=0",
                "Pragma": "no-cache",
                "Expires": "0",
                "X-Accel-Buffering": "no",  # ask nginx-style proxies not to buffer
            },
        )

    # Non-streaming: fetch the whole response, then forward it.
    async with httpx.AsyncClient(timeout=300.0) as client:
        upstream = await client.request(
            request.method,
            upstream_url,
            content=body,
            headers=headers,
        )

    # Drop headers that no longer describe the body being re-sent: httpx
    # transparently decompresses, so upstream Content-Length and
    # Content-Encoding would be wrong; Transfer-Encoding and Connection
    # are hop-by-hop and must not be forwarded.
    forwarded = {
        name: value
        for name, value in upstream.headers.items()
        if name.lower()
        not in {"content-length", "content-encoding", "transfer-encoding", "connection"}
    }
    return Response(
        content=upstream.content,
        status_code=upstream.status_code,
        headers=forwarded,
        media_type=upstream.headers.get("content-type"),
    )
# Root level endpoints (for compatibility)
@app.get("/api/tags")
async def list_models():
    """List available models - Ollama compatible"""
    # The GET reads the whole body eagerly, so the client may be closed
    # before the Response object is built.
    async with httpx.AsyncClient(timeout=10.0) as client:
        upstream = await client.get(f"{OLLAMA_BASE_URL}/api/tags")
    return Response(
        content=upstream.content,
        status_code=upstream.status_code,
        media_type="application/json",
    )
@app.post("/api/generate")
async def generate(request: Request):
    """Generate completion - streams Ollama's NDJSON output straight through.

    NOTE(review): the catch-all ``/api/{path:path}`` route is registered
    earlier and likely matches POST /api/generate first -- verify which
    handler actually serves traffic.
    """
    body = await request.body()

    async def stream_response():
        # BUGFIX: open the client inside the generator.  The original
        # wrapped `return StreamingResponse(...)` in an
        # `async with httpx.AsyncClient(...)`, which closed the client
        # before the response body was consumed.
        async with httpx.AsyncClient(timeout=300.0) as client:
            async with client.stream(
                "POST",
                f"{OLLAMA_BASE_URL}/api/generate",
                content=body,
            ) as upstream:
                async for chunk in upstream.aiter_raw():
                    yield chunk

    return StreamingResponse(
        stream_response(),
        media_type="application/x-ndjson",
        headers={
            "Cache-Control": "no-store, no-cache, must-revalidate, max-age=0",
            "Pragma": "no-cache",
            "Expires": "0",
            "X-Accel-Buffering": "no",  # ask nginx-style proxies not to buffer
        },
    )
@app.post("/api/chat")
async def chat(request: Request):
    """Chat completion - streams Ollama's NDJSON output straight through.

    NOTE(review): the catch-all ``/api/{path:path}`` route is registered
    earlier and likely matches POST /api/chat first -- verify which
    handler actually serves traffic.
    """
    body = await request.body()

    async def stream_response():
        # BUGFIX: open the client inside the generator.  The original
        # wrapped `return StreamingResponse(...)` in an
        # `async with httpx.AsyncClient(...)`, which closed the client
        # before the response body was consumed.
        async with httpx.AsyncClient(timeout=300.0) as client:
            async with client.stream(
                "POST",
                f"{OLLAMA_BASE_URL}/api/chat",
                content=body,
            ) as upstream:
                async for chunk in upstream.aiter_raw():
                    yield chunk

    return StreamingResponse(
        stream_response(),
        media_type="application/x-ndjson",
        headers={
            "Cache-Control": "no-store, no-cache, must-revalidate, max-age=0",
            "Pragma": "no-cache",
            "Expires": "0",
            "X-Accel-Buffering": "no",  # ask nginx-style proxies not to buffer
        },
    )
if __name__ == "__main__":
    import uvicorn
    # Serve on all interfaces; 7860 is the port Hugging Face Spaces
    # exposes -- TODO confirm the deployment target.
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")