"""Emotion classifier service: a REST API and a Gradio UI on one FastAPI app.

Routes:
    POST /predict  -- JSON body ``{"text": ...}``; returns the prediction.
    GET  /health   -- liveness probe (does not touch the model).
    /              -- Gradio web UI, mounted after the routes above.

The predictor is built once at import time from the identifier in the
``MODEL_ID`` environment variable (default
``"AfroLogicInsect/emotionClassifier"``).
"""

import os

import gradio as gr
from fastapi import FastAPI
from fastapi.responses import RedirectResponse  # NOTE(review): unused in this module
from pydantic import BaseModel

from app.ml.predictor import EmotionPredictor

app = FastAPI(title="Emotion Classifier", version="1.0.0")

MODEL_ID = os.getenv("MODEL_ID", "AfroLogicInsect/emotionClassifier")
# Built once at import time and shared by the REST endpoint and the UI callback.
predictor = EmotionPredictor(MODEL_ID)


class PredictRequest(BaseModel):
    """Request body for ``POST /predict``."""

    text: str  # raw text to classify


class PredictResponse(BaseModel):
    """Response body for ``POST /predict``; mirrors ``predictor.predict``'s dict."""

    emotion: str  # predicted emotion label
    confidence: float  # confidence for ``emotion`` (shown as a percentage in the UI)
    all_emotions: dict[str, float]  # per-label scores, fed straight to gr.Label


# --- REST endpoints ---
@app.post("/predict", response_model=PredictResponse)
def predict(request: PredictRequest):
    """Classify ``request.text`` and return the predicted emotion with scores.

    Deliberately a plain ``def`` rather than ``async def``: ``predictor.predict``
    is a synchronous, blocking call. FastAPI runs sync endpoints in its worker
    threadpool, so inference no longer stalls the event loop that also serves
    ``/health`` and the mounted Gradio UI.

    NOTE(review): predictions may therefore run concurrently in worker threads
    (the Gradio callback is already invoked off the event loop); confirm that
    ``EmotionPredictor`` tolerates concurrent calls.
    """
    result = predictor.predict(request.text)
    return PredictResponse(**result)


@app.get("/health")
async def health():
    """Liveness probe; returns a constant payload without touching the model."""
    return {"status": "ok"}


# --- Gradio UI ---
def classify(text: str) -> tuple[str, dict]:
    """Gradio callback: return ``(summary label, per-emotion scores)`` for ``text``.

    Empty, whitespace-only, or missing (``None``) input returns a prompt
    message and an empty score dict without calling the model.
    """
    # ``not text`` also covers None, which would otherwise crash on ``.strip()``.
    if not text or not text.strip():
        return "Please enter some text.", {}
    result = predictor.predict(text)
    label = f"{result['emotion'].upper()} ({result['confidence']:.1%} confidence)"
    return label, result["all_emotions"]


with gr.Blocks(title="Emotion Classifier") as demo:
    gr.Markdown(
        """
        # Emotion Classifier
        Fine-tuned DistilBERT that detects six emotions:
        anger, fear, joy, love, sadness, surprise.
        """
    )
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Input text",
                placeholder="Type something…",
                lines=3,
            )
            submit_btn = gr.Button("Classify", variant="primary")
        with gr.Column():
            label_output = gr.Textbox(label="Prediction")
            # Six emotions are listed above, so showing the top 6 shows them all.
            scores_output = gr.Label(label="All emotion scores", num_top_classes=6)

    submit_btn.click(
        fn=classify,
        inputs=text_input,
        outputs=[label_output, scores_output],
    )

    gr.Examples(
        examples=[
            ["I am feeling very happy today!"],
            ["This is so frustrating, nothing works."],
            ["I miss my family so much."],
            ["Wow, I did not expect that at all!"],
        ],
        inputs=text_input,
    )

    gr.Markdown(
        """
        ---
        **API endpoint** also available at `/predict` — see `/docs` for the OpenAPI spec.
        """
    )

# Mount Gradio on FastAPI last; the UI lives at "/" and the REST API at "/predict".
app = gr.mount_gradio_app(app, demo, path="/")