# Last commit by Charan Suresh (6180f4c):
#   "Fix Sage/Lens silent failures and model loading on HF Space"
# (Hugging Face file-view header: Raw · History · Blame · Contribute · Delete · 2.42 kB)
import logging
import threading
from contextlib import asynccontextmanager
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
try:
from inference import generate_text, generate_with_image, health_check, load_model
except ModuleNotFoundError:
from .inference import generate_text, generate_with_image, health_check, load_model
# Module-level logger named after this module, shared by every handler below.
logger = logging.getLogger(name=__name__)
def _preload_model() -> None:
    """Warm up the model off the request path.

    Meant to be run on a background thread at startup so the first inference
    request does not pay the full cold-start cost. Any ``Exception`` raised by
    ``load_model()`` is logged with its traceback and then swallowed, so the
    preloader simply finishes instead of propagating the error.
    """
    # The success message deliberately stays inside the ``try`` so the handler
    # below covers it as well.
    try:
        load_model()
        logger.info("Model preloaded at startup")
    except Exception:  # best-effort warm-up: a failed preload is only logged
        logger.exception(
            "Background model preload failed — will retry on first request"
        )
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook passed to ``FastAPI(lifespan=...)``.

    On startup it launches ``_preload_model`` on a daemon thread named
    "model-preloader" and yields immediately without waiting for the load to
    finish. Nothing runs after ``yield``: shutdown performs no cleanup, and the
    daemon flag means the preloader thread does not keep the process alive.
    """
    preloader = threading.Thread(
        target=_preload_model,
        name="model-preloader",
        daemon=True,
    )
    preloader.start()
    yield


# Application instance; its lifespan hook starts the background model preload.
app = FastAPI(lifespan=lifespan)
# Request body for POST /api/generate. Documented with comments rather than a
# docstring because pydantic would publish a class docstring as the schema
# description.
class GenerateRequest(BaseModel):
    # Text prompt forwarded to generate_text().
    prompt: str
    # Model selector. The /api/generate handler maps any value containing "e2b"
    # (case-insensitive) to "e2b" and everything else to "e4b".
    model: str = "e4b"
    # Generation budget forwarded as-is; no bounds are enforced here.
    max_tokens: int = 512
# Request body for POST /api/generate_vision. Documented with comments rather
# than a docstring because pydantic would publish a class docstring as the
# schema description.
class VisionRequest(BaseModel):
    # Text prompt accompanying the image.
    prompt: str
    # Image as a base64 string, passed to generate_with_image() without decoding
    # here; the accepted image format is up to the inference module.
    image_base64: str
    # NOTE(review): accepted for parity with GenerateRequest, but the
    # /api/generate_vision handler never passes it to generate_with_image() —
    # confirm whether vision is intentionally pinned to a single model.
    model: str = "e4b"
    # Generation budget forwarded to generate_with_image(); not bounds-checked here.
    max_tokens: int = 512
@app.get("/")
def root():
    """Liveness banner for the service root; always returns the same static JSON."""
    banner = "KnowLedge Inference API is running"
    return {"message": banner}
@app.get("/api/health")
def health():
    """Return the payload from ``health_check()`` unchanged.

    Per the inference module's contract (not visible in this file),
    ``health_check()`` only reports the current model-load state and does not
    block waiting for the model to load.
    """
    status = health_check()
    return status
@app.post("/api/generate")
def generate(req: GenerateRequest):
    """Generate a text completion for ``req.prompt``.

    Model selection is a substring match: any ``req.model`` containing "e2b"
    (case-insensitive) uses "e2b"; every other value falls back to "e4b".
    Declared as a plain ``def``, so FastAPI runs it in its worker threadpool
    rather than on the event loop.

    Returns:
        ``{"response": <generated text>}``.

    Raises:
        HTTPException: 503 when ``generate_text`` raises ``RuntimeError``
            (treated as "service unavailable"), 500 for any other exception.
            In both cases the detail is ``str(exc)``.
    """
    model_name = "e2b" if "e2b" in req.model.lower() else "e4b"
    try:
        response = generate_text(model_name, req.prompt, req.max_tokens)
    except RuntimeError as exc:
        # Log before mapping to 503 so the failure is visible server-side, not
        # only in the client's response.
        logger.warning("Text generation unavailable (model=%s): %s", model_name, exc)
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("Text generation error")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {"response": response}
@app.post("/api/generate_vision")
def generate_vision(req: VisionRequest):
    """Generate a response for ``req.prompt`` conditioned on a base64 image.

    Declared as a plain ``def``, so FastAPI runs it in its worker threadpool
    rather than on the event loop.

    Returns:
        ``{"response": <generated text>}``.

    Raises:
        HTTPException: 503 when ``generate_with_image`` raises ``RuntimeError``
            (treated as "service unavailable"), 500 for any other exception.
            In both cases the detail is ``str(exc)``.
    """
    # NOTE(review): req.model is accepted by VisionRequest but not forwarded;
    # generate_with_image() receives no model argument. Confirm whether vision
    # is meant to be pinned to a single model.
    try:
        response = generate_with_image(req.prompt, req.image_base64, req.max_tokens)
    except RuntimeError as exc:
        # Log before mapping to 503 so the failure is visible server-side, not
        # only in the client's response.
        logger.warning("Vision generation unavailable: %s", exc)
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("Vision generation error")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {"response": response}
if __name__ == "__main__":
    # Direct-run entry point: serve `app` on every interface at port 7860
    # (presumably the port the HF Space routes to — confirm in the Space config).
    uvicorn.run(app, port=7860, host="0.0.0.0")