# Summarizer / app.py
# Author: viskav — "Update app.py" (commit 105b25f, verified)
# --- Imports (stdlib first, then third-party) ------------------------------
import logging
import os
import re
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

# Module-level logger for startup and request diagnostics.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# llama-cpp-python provides the GGUF inference backend; fail fast with an
# actionable install hint when it is absent.
try:
    from llama_cpp import Llama
except ImportError:
    raise ImportError("Install llama-cpp-python: pip install llama-cpp-python")

# Quantized Phi-3.5-mini instruct model hosted on the Hugging Face Hub.
MODEL_REPO = "bartowski/Phi-3.5-mini-instruct-GGUF"
MODEL_FILE = "Phi-3.5-mini-instruct-Q4_K_M.gguf"

# Global model handle and loading flag, managed by the lifespan hook below.
llm = None
model_loading = False
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the GGUF model at startup and release it at shutdown.

    Sets the module-level ``llm`` handle and ``model_loading`` flag that the
    health/readiness endpoints inspect. A load failure is logged and leaves
    ``llm`` as ``None`` so the app still starts and reports itself unhealthy
    instead of crashing the container.
    """
    global llm, model_loading
    try:
        logger.info("🚀 Starting model load...")
        model_loading = True
        # Honor the Hugging Face Spaces cache location; fall back to ./models.
        cache_dir = os.getenv("HF_HOME", "./models")
        llm = Llama.from_pretrained(
            repo_id=MODEL_REPO,
            filename=MODEL_FILE,
            # Fix: cache_dir was computed but never passed, so the HF_HOME
            # override had no effect. Now forwarded to hf_hub_download.
            cache_dir=cache_dir,
            n_threads=4,
            n_ctx=2048,
            n_batch=256,
            n_gpu_layers=0,  # CPU-only inference on Spaces free tier
            verbose=False,
        )
        model_loading = False
        logger.info("✅ Model loaded and ready")
    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception("❌ Model load error: %s", e)
        model_loading = False
        llm = None
    yield
    logger.info("🛑 Shutting down...")
    if llm:
        del llm
# FastAPI application; the lifespan hook above loads/unloads the model.
app = FastAPI(
    title="AI Summarizer",
    description="Fast & Accurate AI Text Summarizer",
    version="1.0",
    lifespan=lifespan
)
# CORS is wide open (any origin/method/header) — typical for a public demo
# Space. NOTE(review): tighten allow_origins if this ever fronts private data.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
class SummarizeRequest(BaseModel):
    """Request body for ``POST /api/summarize``."""

    # Text to summarize; capped so the prompt fits the 2048-token context.
    text: str = Field(..., min_length=1, max_length=2000)
    # Summary size preset; pydantic rejects anything but the three options.
    length: str = Field(default="short", pattern="^(short|medium|long)$")
# Prompt instruction per length preset (keys match SummarizeRequest.length).
LENGTH_INSTRUCTIONS = {
    "short": "Summarize in 2–3 concise sentences.",
    "medium": "Summarize in 4–5 clear sentences.",
    "long": "Summarize in a detailed paragraph.",
}
def clean_output(text: str) -> str:
    """Strip chat-template special tokens and normalize whitespace."""
    # Remove anything shaped like <|...|> (e.g. <|end|>, <|assistant|>).
    without_tokens = re.sub(r"<\|.*?\|>", "", text)
    # Collapse every whitespace run (including newlines/tabs) to one space.
    collapsed = re.sub(r"\s+", " ", without_tokens)
    return collapsed.strip()
@app.get("/")
def root():
    """Root endpoint: lightweight liveness/status report."""
    status_payload = {
        "status": "healthy",
        "model_loaded": llm is not None,
        "model_loading": model_loading,
        "message": "AI Summarizer API is running",
    }
    return status_payload
@app.get("/health")
def health():
    """Health check endpoint for container orchestration.

    Reports one of three states: "starting" while the model loads,
    "unhealthy" if the load failed, "healthy" when ready.
    """
    report = {"model_loaded": False, "model_loading": False}
    if model_loading:
        report.update(
            status="starting",
            model_loading=True,
            message="Model is loading, please wait...",
        )
    elif llm is None:
        report.update(
            status="unhealthy",
            message="Model failed to load",
        )
    else:
        report.update(
            status="healthy",
            model_loaded=True,
            model_name=MODEL_FILE,
            message="Ready to summarize",
        )
    return report
@app.get("/ready")
def readiness():
    """Readiness probe: 200 once the model is usable, 503 otherwise."""
    model_ready = llm is not None and not model_loading
    if not model_ready:
        raise HTTPException(status_code=503, detail="Model not ready")
    return {"ready": True}
@app.post("/api/summarize")
async def summarize(req: SummarizeRequest):
    """Summarize ``req.text`` at the requested length preset.

    Returns ``{"summary", "success", "length"}`` on success. Raises 503
    while the model is loading or failed to load, and 500 on generation
    errors or empty model output.
    """
    if model_loading:
        raise HTTPException(
            status_code=503,
            detail="Model is still loading. Please wait and try again."
        )
    if llm is None:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded. Check server logs."
        )
    try:
        text = req.text.strip()
        # req.length is already pattern-validated by pydantic; .get() is a
        # defensive fallback to "short".
        length_instruction = LENGTH_INSTRUCTIONS.get(
            req.length,
            LENGTH_INSTRUCTIONS["short"]
        )
        # Phi-3.5 chat template: the user turn ends with <|end|> and the
        # assistant turn is left open for generation.
        prompt = f"""<|user|>
You are an expert text summarizer.
{length_instruction}
Text:
{text}
<|end|>
<|assistant|>"""
        # Token budget per preset; roughly tracks the sentence counts asked
        # for in LENGTH_INSTRUCTIONS.
        max_tokens_map = {
            "short": 140,
            "medium": 220,
            "long": 300
        }
        logger.info("Summarizing text (length: %s)", req.length)
        output = llm(
            prompt,
            max_tokens=max_tokens_map.get(req.length, 140),
            temperature=0.3,  # low temperature for faithful, stable summaries
            top_p=0.9,
            top_k=40,
            repeat_penalty=1.05,
            stop=["<|end|>", "<|user|>"],  # halt at end of assistant turn
            echo=False
        )
        summary = clean_output(output["choices"][0]["text"])
        if not summary:
            raise HTTPException(
                status_code=500,
                detail="Model produced empty output"
            )
        logger.info("✅ Summary generated successfully")
        return {
            "summary": summary,
            "success": True,
            "length": req.length
        }
    except HTTPException:
        # Re-raise our own HTTP errors (e.g. the empty-output 500) untouched.
        raise
    except Exception as e:
        # Fix: logger.exception keeps the traceback; logger.error lost it.
        logger.exception("Summarization error: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Summarization error: {str(e)}"
        )
if __name__ == "__main__":
    import uvicorn

    # Hugging Face Spaces injects PORT; 7860 is the Spaces default.
    serve_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=serve_port)