File size: 1,766 Bytes
63c461f
 
 
c51e08b
63c461f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c51e08b
 
 
63c461f
c51e08b
63c461f
 
 
 
 
 
 
c51e08b
63c461f
 
 
 
 
 
 
c51e08b
 
63c461f
 
 
c51e08b
 
63c461f
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import logging
import os
from contextlib import asynccontextmanager
from typing import Union, List

import uvicorn
from fastapi import FastAPI, HTTPException
from huggingface_hub import HfApi

from predict_urgency_model import UrgencyPredictor
from response_schema import TextInput, UrgencyClassificationOutput


# --- Model repository configuration ---------------------------------------
# Hugging Face model repo the predictor loads; override via the MODEL_REPO
# environment variable.
model_repo = os.environ.get("MODEL_REPO", "sambodhan/sambodhan_urgency_classifier")

# Shared Hub client; the root endpoint uses it to look up the repo's git tags.
hf_api = HfApi()


# Startup and shutdown

# Loaded UrgencyPredictor, or None while no model is loaded. Defined at module
# level so request handlers can test for it even when the lifespan startup has
# not run (e.g. a test client that does not enter the lifespan), instead of
# failing with a NameError.
predictor: Union[UrgencyPredictor, None] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the urgency model at startup and release it at shutdown.

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).

    If constructing ``UrgencyPredictor`` raises, the exception propagates out
    of startup and the application does not begin serving.
    """
    global predictor
    predictor = UrgencyPredictor(model_repo=model_repo)
    try:
        yield
    finally:
        # Drop the reference on shutdown so the model can be garbage-collected
        # and handlers report the model as no longer loaded.
        predictor = None


# FastAPI app

# Application instance; model loading and release are delegated to `lifespan`.
app = FastAPI(
    title="Sambodhan Urgency Classifier API",
    version="1.0.0",
    description="AI model that classifies citizen grievances by urgency with confidence scores.",
    lifespan=lifespan,
)


# Routes

@app.post("/predict_urgency", response_model=Union[UrgencyClassificationOutput, List[UrgencyClassificationOutput]])
def predict_urgency(input_data: TextInput):
    """Classify the urgency of the submitted grievance text.

    Declared as a plain ``def`` so FastAPI runs it in its threadpool rather
    than on the event loop, which suits blocking model inference.

    Args:
        input_data: Request body; its ``text`` field is passed unchanged to
            ``predictor.predict``.

    Returns:
        Whatever ``predictor.predict`` returns, validated by FastAPI against
        ``UrgencyClassificationOutput`` or a list of them.

    Raises:
        HTTPException: 503 if the model is not loaded; 500 if prediction raises.
    """
    # Looked up defensively: the module-level global may not exist, or may be
    # None, when the lifespan startup has not run or has already shut down.
    model = globals().get("predictor")
    if model is None:
        raise HTTPException(status_code=503, detail="Model is not loaded.")

    try:
        prediction = model.predict(input_data.text)
    except Exception as e:
        # Keep the full traceback in the server log; the client only sees the message.
        logging.getLogger(__name__).exception("Urgency prediction failed")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {e}") from e
    return prediction

@app.get("/")
def root():
    """Report that the service is up, whether the model is loaded, and its version.

    Returns:
        dict with ``message``; ``status`` ("Active" when the predictor is
        loaded, otherwise "Inactive"); and ``model_version``. ``model_version``
        is a git tag name of ``model_repo`` on the Hugging Face Hub, or
        "unknown" if the lookup raises or the repo has no tags.
    """
    try:
        refs = hf_api.list_repo_refs(repo_id=model_repo, repo_type="model")
        # NOTE(review): assumes tags[0] is the newest tag; tag ordering is not
        # guaranteed by this code — confirm, or select by a parsed version.
        latest_tag = refs.tags[0].name
    except Exception:
        # Best-effort: a Hub error or a tagless repo (IndexError) must not
        # break this status endpoint.
        latest_tag = "unknown"

    # The predictor global only exists (and is non-None) once startup has run.
    model_loaded = globals().get("predictor") is not None

    return {
        "message": "Sambodhan Urgency Classifier API is running.",
        "status": "Active" if model_loaded else "Inactive",
        "model_version": latest_tag,
    }


# For local testing (optional)

# if __name__ == "__main__":
#     port = int(os.getenv("PORT", 7860))
#     uvicorn.run("app:app", host="0.0.0.0", port=port)