# POC2PROD / api.py
# Author: maxcasado
# Last commit: 37ced21 "Update api.py" (verified)
# Size: 390 bytes
# api.py
from typing import List, Dict, Any

from fastapi import FastAPI
from pydantic import BaseModel, Field

from model_utils import predict_proba
# FastAPI application instance; the endpoint(s) in this module are
# registered on it through its route decorators (e.g. ``@app.post``).
app = FastAPI(
    title="StackOverflow Tagger API",
)
class Query(BaseModel):
    """Request body accepted by the ``/predict`` endpoint.

    Fields are validated by pydantic before the endpoint handler is called.
    """

    # Text to be tagged. NOTE(review): presumably a StackOverflow question
    # (title and/or body) -- confirm what the model in model_utils expects.
    text: str
    # How many tags to return; defaults to 10. Constrained to >= 1 because a
    # zero or negative value would otherwise be forwarded unchecked to
    # ``predict_proba``, whose handling of it is not visible here. Invalid
    # values are now rejected as a request-validation error instead.
    top_k: int = Field(10, ge=1)
@app.post("/predict")
def predict(q: Query) -> Dict[str, Any]:
    """Run the tagger on the submitted text and return its predictions.

    Args:
        q: Request payload (see ``Query``) holding the text to tag and the
            number of tags wanted (``top_k``).

    Returns:
        A dict with a single key, ``"tags"``, mapped to whatever
        ``predict_proba`` returns. NOTE(review): presumably ranked tags with
        their probabilities -- confirm against ``model_utils``.
    """
    predicted_tags = predict_proba(q.text, top_k=q.top_k)
    response: Dict[str, Any] = {"tags": predicted_tags}
    return response