# topic-model / app.py
# Author: abyayel — last commit cea26d1 ("Update app.py").
import os
import secrets
import threading
from typing import List, Optional

from fastapi import FastAPI, HTTPException, Depends
from fastapi.concurrency import run_in_threadpool
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
from transformers import pipeline
app = FastAPI()

# Shared secret for internal callers, read from the environment. The service
# refuses to start when it is missing or empty.
API_KEY = os.environ.get("INTERNAL_API_KEY")
if API_KEY is None or API_KEY == "":
    raise RuntimeError("INTERNAL_API_KEY environment variable not set")

# Request header that must carry the key.
API_KEY_NAME = "X-Internal-API-Key"
# auto_error=False: a missing header is delivered to the dependency as None
# instead of triggering an automatic error response.
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
async def verify_api_key(api_key: Optional[str] = Depends(api_key_header)) -> str:
    """Reject the request unless it carries the configured internal API key.

    Used as a FastAPI dependency. ``api_key`` is the raw value of the API-key
    header, or ``None`` when the header is absent.

    Returns:
        The validated key.

    Raises:
        HTTPException: 403 when the key is missing or does not match.
    """
    if not api_key:
        raise HTTPException(status_code=403, detail="Invalid API Key")
    # Constant-time comparison: a plain `!=` can return early on the first
    # mismatching character, letting response timing leak the secret.
    # Compare bytes because compare_digest raises TypeError for non-ASCII
    # str, and the header value is client-controlled.
    if not secrets.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Invalid API Key")
    return api_key
# Checkpoint used for zero-shot topic classification. The pipeline is built
# once at import time and shared by every request.
MODEL_NAME = "vicgalle/xlm-roberta-large-xnli-anli"
model = pipeline(task="zero-shot-classification", model=MODEL_NAME)
# Fallback label set scored when a request supplies no candidate topics.
DEFAULT_TOPICS = [
    # Basic services
    "Health",
    "Education",
    "Water Supply",
    "Electricity",
    "Housing",
    "Transport",
    # Infrastructure
    "Roads",
    "Bridges",
    "Railways",
    "Airports",
    "Digital Infrastructure",
    # Agriculture and environment
    "Agriculture",
    "Irrigation",
    "Livestock",
    "Forestry",
    "Environment",
    "Climate Change",
    # Economy
    "Economy",
    "Employment",
    "Small Business",
    "Industry",
    "Trade",
    "Tourism",
    # Social welfare
    "Social Protection",
    "Pension",
    "Disability Support",
    "Food Security",
    "Poverty Reduction",
    # Governance and security
    "Governance",
    "Justice",
    "Police",
    "Defense",
    "Public Safety",
    # Land and settlement
    "Urban Planning",
    "Rural Development",
    "Land Administration",
    "Migration",
    # Community and society
    "Sports",
    "Culture",
    "Youth",
    "Women Affairs",
    "Diaspora",
]
class TopicRequest(BaseModel):
    """Request body for ``POST /suggest-topics``."""

    # Text to classify. Blank text is rejected by the endpoint, not here.
    text: str
    # Optional labels to score the text against. When omitted (None) or an
    # empty list, the endpoint falls back to DEFAULT_TOPICS.
    candidate_topics: Optional[List[str]] = None
class TopicResponse(BaseModel):
    """Response body for ``POST /suggest-topics``."""

    # Highest-ranked topics; each entry is built by the endpoint as
    # {"topic": <label>, "confidence": <score>}. The dict shape is not
    # validated by this model.
    topics: List[dict]
# Number of highest-scoring topics returned by /suggest-topics.
TOP_K = 3

# Inference runs in a worker thread (see suggest_topics). This lock keeps
# model calls serialized, as they were when the model ran on the event loop:
# sharing one pipeline/tokenizer across threads at once is not documented
# as safe.
_MODEL_LOCK = threading.Lock()


def _classify(text: str, candidate: List[str]) -> dict:
    """Run the zero-shot pipeline on *text* against *candidate* labels.

    Blocking; intended to be executed in a worker thread. Returns the raw
    pipeline output, whose ``labels``/``scores`` lists are parallel and
    ordered from most to least likely.
    """
    with _MODEL_LOCK:
        return model(text, candidate)


@app.post("/suggest-topics", response_model=TopicResponse)
async def suggest_topics(request: TopicRequest, _=Depends(verify_api_key)):
    """Suggest the most likely topics for ``request.text``.

    Scores the text against ``request.candidate_topics`` (or DEFAULT_TOPICS
    when none are given) and returns up to TOP_K topics with confidences,
    best first.

    Raises:
        HTTPException: 400 when the text is empty or whitespace-only.
    """
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Empty text")
    candidate = request.candidate_topics or DEFAULT_TOPICS
    # Model inference is synchronous and CPU-heavy. Calling it directly in
    # an async handler would stall the event loop, and with it every other
    # request including /health, so offload it to the threadpool.
    result = await run_in_threadpool(_classify, request.text, candidate)
    best = list(zip(result["labels"], result["scores"]))[:TOP_K]
    top = [{"topic": label, "confidence": score} for label, score in best]
    return {"topics": top}
@app.get("/health")
async def health():
    """Lightweight health check that never touches the model."""
    status_payload = dict(status="ok")
    return status_payload