# Web file-viewer residue captured with the source (not part of the program):
#   author: AfroLogicInsect
#   commit 811fc57 - "initial migration from Render: FastAPI + Gradio, port 7860"
#   file size: 2.44 kB
import asyncio
import os

import gradio as gr
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from pydantic import BaseModel

from app.ml.predictor import EmotionPredictor
# Model identifier: the MODEL_ID environment variable wins, otherwise the
# built-in default is used.
MODEL_ID = os.environ.get("MODEL_ID", "AfroLogicInsect/emotionClassifier")

# ASGI application object that the REST routes in this module register on.
app = FastAPI(version="1.0.0", title="Emotion Classifier")

# A single predictor is constructed at import time and reused by every caller
# in this module.
predictor = EmotionPredictor(MODEL_ID)
class PredictRequest(BaseModel):
    # Request body accepted by POST /predict.
    # Documented with `#` comments rather than a docstring so the generated
    # model schema is left exactly as it was.

    # Raw text to be classified.
    text: str
class PredictResponse(BaseModel):
    # Response body returned by POST /predict.
    # Documented with `#` comments rather than a docstring so the generated
    # model schema is left exactly as it was.

    # Label of the predicted emotion.
    emotion: str

    # Confidence value reported alongside `emotion`.
    confidence: float

    # Float score per label, keyed by label name (presumably one entry per
    # emotion the model knows - confirm against EmotionPredictor).
    all_emotions: dict[str, float]
# --- existing REST endpoint (unchanged) ---
@app.post("/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
    # Classify `request.text` and return the outcome as a PredictResponse.
    #
    # `predictor.predict` is an ordinary synchronous call.  Invoking it
    # directly inside this coroutine would stall the event loop - and with it
    # every other in-flight request, including /health - for the whole
    # inference.  The call is therefore handed to a worker thread and awaited.
    #
    # Kept as comments instead of a docstring so the published OpenAPI
    # description of this route does not change.
    result = await asyncio.to_thread(predictor.predict, request.text)

    # The predictor's mapping is unpacked straight into the response model, so
    # its keys must line up with the PredictResponse fields.
    return PredictResponse(**result)
@app.get("/health")
async def health():
    # Health probe: replies with a constant payload and never touches the
    # predictor.
    status_payload = {"status": "ok"}
    return status_payload
# --- Gradio UI ---
def classify(text: str) -> tuple[str, dict[str, float]]:
    """Classify *text* and shape the result for the two Gradio outputs.

    Args:
        text: Contents of the input textbox. ``None`` is tolerated and handled
            like an empty string instead of raising ``AttributeError``.

    Returns:
        A ``(label, scores)`` pair. ``label`` is the upper-cased emotion
        followed by its confidence as a percentage; ``scores`` is the
        predictor's ``all_emotions`` mapping. For missing or whitespace-only
        input the pair is a prompt message and an empty mapping, and the
        predictor is not called.
    """
    # Guard first: nothing to classify, so skip the model entirely.
    if text is None or not text.strip():
        return "Please enter some text.", {}

    result = predictor.predict(text)
    emotion = result["emotion"].upper()
    confidence = result["confidence"]
    # `.1%` multiplies by 100 and appends "%", so `confidence` is formatted as
    # a fraction (e.g. 0.873 -> "87.3%").
    label = f"{emotion} ({confidence:.1%} confidence)"
    return label, result["all_emotions"]
# Gradio front-end: heading, a two-column row (input on the left, results on
# the right), example prompts and a footer pointing at the REST API.
# NOTE(review): the nesting below was reconstructed from a listing that had
# lost its indentation - confirm that the examples and footer belong at the
# Blocks level rather than inside a column.
with gr.Blocks(title="Emotion Classifier") as demo:
    # Page heading and one-line model description.
    gr.Markdown(
        value="""
# Emotion Classifier
Fine-tuned DistilBERT that detects six emotions: anger, fear, joy, love, sadness, surprise.
"""
    )

    with gr.Row():
        # Left column: free-text input and the trigger button.
        with gr.Column():
            text_input = gr.Textbox(
                lines=3,
                placeholder="Type something…",
                label="Input text",
            )
            submit_btn = gr.Button(value="Classify", variant="primary")

        # Right column: headline prediction and the per-label scores.
        with gr.Column():
            label_output = gr.Textbox(label="Prediction")
            scores_output = gr.Label(num_top_classes=6, label="All emotion scores")

    # classify() returns (label, scores); the outputs list follows that order.
    submit_btn.click(
        classify,
        inputs=text_input,
        outputs=[label_output, scores_output],
    )

    # Sample sentences offered for the input box; each row supplies one value.
    gr.Examples(
        inputs=text_input,
        examples=[
            ["I am feeling very happy today!"],
            ["This is so frustrating, nothing works."],
            ["I miss my family so much."],
            ["Wow, I did not expect that at all!"],
        ],
    )

    # Footer: horizontal rule plus a pointer to the REST endpoint and its docs.
    gr.Markdown(
        value="""
---
**API endpoint** also available at `/predict` — see `/docs` for the OpenAPI spec.
"""
    )
# Attach the Gradio UI to the FastAPI application at the root path "/"; the
# REST API stays reachable at "/predict". `app` is rebound to the object
# returned by the mount call.
app = gr.mount_gradio_app(app=app, blocks=demo, path="/")