"""AI text summarizer service.

FastAPI application that loads a quantized Phi-3.5-mini GGUF model via
llama-cpp-python at startup and exposes health probes plus a single
summarization endpoint. Designed for Hugging Face Spaces (CPU-only,
PORT env var, HF_HOME cache directory).
"""

from contextlib import asynccontextmanager
import logging
import os
import re

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    from llama_cpp import Llama
except ImportError:
    raise ImportError("Install llama-cpp-python: pip install llama-cpp-python")

MODEL_REPO = "bartowski/Phi-3.5-mini-instruct-GGUF"
MODEL_FILE = "Phi-3.5-mini-instruct-Q4_K_M.gguf"

# Module-level model state shared by all request handlers.
# llm is None until the load completes (or permanently on failure);
# model_loading is True only while the download/load is in progress.
llm = None
model_loading = False


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the GGUF model on startup and release it on shutdown.

    On load failure the exception is logged and ``llm`` stays ``None``
    so the app still starts and the health endpoints can report the
    failure instead of crashing the container.
    """
    global llm, model_loading
    try:
        logger.info("🚀 Starting model load...")
        model_loading = True
        # Set cache directory for Hugging Face Spaces
        cache_dir = os.getenv("HF_HOME", "./models")
        llm = Llama.from_pretrained(
            repo_id=MODEL_REPO,
            filename=MODEL_FILE,
            n_threads=4,       # CPU-only Spaces: modest thread count
            n_ctx=2048,        # context window; request text is capped at 2000 chars
            n_batch=256,
            n_gpu_layers=0,    # no GPU on the free tier
            verbose=False,
        )
        model_loading = False
        logger.info("✅ Model loaded and ready")
    except Exception:
        # logger.exception preserves the traceback, unlike an f-string error line.
        logger.exception("❌ Model load error")
        model_loading = False
        llm = None
    yield
    logger.info("🛑 Shutting down...")
    if llm:
        del llm


app = FastAPI(
    title="AI Summarizer",
    description="Fast & Accurate AI Text Summarizer",
    version="1.0",
    lifespan=lifespan,
)

# Wide-open CORS: this is a public demo API with no credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


class SummarizeRequest(BaseModel):
    """Request body for POST /api/summarize."""

    # Text to summarize; max_length stays under the 2048-token context window.
    text: str = Field(..., min_length=1, max_length=2000)
    # Desired summary length; validated against the three supported presets.
    length: str = Field(default="short", pattern="^(short|medium|long)$")


LENGTH_INSTRUCTIONS = {
    "short": "Summarize in 2–3 concise sentences.",
    "medium": "Summarize in 4–5 clear sentences.",
    "long": "Summarize in a detailed paragraph.",
}


def clean_output(text: str) -> str:
    """Clean model output from special tokens"""
    # Strip chat-template control tokens like <|end|>, <|assistant|>.
    text = re.sub(r"<\|.*?\|>", "", text)
    # Collapse all runs of whitespace to single spaces.
    text = re.sub(r"\s+", " ", text)
    return text.strip()


@app.get("/")
def root():
    """Root endpoint - returns status"""
    return {
        "status": "healthy" if llm is not None else "degraded",
        "model_loaded": llm is not None,
        "model_loading": model_loading,
        "message": "AI Summarizer API is running",
    }


@app.get("/health")
def health():
    """Health check endpoint for container orchestration"""
    if model_loading:
        return {
            "status": "starting",
            "model_loaded": False,
            "model_loading": True,
            "message": "Model is loading, please wait...",
        }
    if llm is None:
        return {
            "status": "unhealthy",
            "model_loaded": False,
            "model_loading": False,
            "message": "Model failed to load",
        }
    return {
        "status": "healthy",
        "model_loaded": True,
        "model_loading": False,
        "model_name": MODEL_FILE,
        "message": "Ready to summarize",
    }


@app.get("/ready")
def readiness():
    """Readiness probe - returns 200 only when model is loaded"""
    if llm is not None and not model_loading:
        return {"ready": True}
    raise HTTPException(status_code=503, detail="Model not ready")


@app.post("/api/summarize")
async def summarize(req: SummarizeRequest):
    """Summarize ``req.text`` with the loaded model.

    Returns ``{"summary", "success", "length"}`` on success.
    Raises 503 while the model is loading or failed to load, and 500
    on empty model output or unexpected inference errors.
    """
    if model_loading:
        raise HTTPException(
            status_code=503,
            detail="Model is still loading. Please wait and try again.",
        )
    if llm is None:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded. Check server logs.",
        )
    try:
        text = req.text.strip()
        length_instruction = LENGTH_INSTRUCTIONS.get(
            req.length, LENGTH_INSTRUCTIONS["short"]
        )

        # Phi-3.5 chat template: <|user|> ... <|end|> <|assistant|>
        prompt = f"""<|user|>
You are an expert text summarizer. 
{length_instruction} Text: {text} <|end|> <|assistant|>"""

        # Output budget per requested summary length.
        max_tokens_map = {
            "short": 140,
            "medium": 220,
            "long": 300,
        }

        logger.info("Summarizing text (length: %s)", req.length)

        output = llm(
            prompt,
            max_tokens=max_tokens_map.get(req.length, 140),
            temperature=0.3,       # low temperature for faithful summaries
            top_p=0.9,
            top_k=40,
            repeat_penalty=1.05,
            stop=["<|end|>", "<|user|>"],
            echo=False,
        )

        summary = clean_output(output["choices"][0]["text"])

        if not summary:
            raise HTTPException(
                status_code=500,
                detail="Model produced empty output",
            )

        logger.info("✅ Summary generated successfully")
        return {
            "summary": summary,
            "success": True,
            "length": req.length,
        }
    except HTTPException:
        # Re-raise our own HTTP errors untouched (e.g. empty-output 500).
        raise
    except Exception as e:
        # logger.exception keeps the traceback for debugging.
        logger.exception("Summarization error")
        raise HTTPException(
            status_code=500,
            detail=f"Summarization error: {str(e)}",
        )


if __name__ == "__main__":
    import uvicorn

    # Use PORT environment variable for Hugging Face Spaces
    port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)