Spaces:
Sleeping
Sleeping
import os

import gradio as gr
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import RedirectResponse
from pydantic import BaseModel

from app.ml.predictor import EmotionPredictor
# Model repository to load; can be overridden via the MODEL_ID environment variable.
MODEL_ID = os.environ.get("MODEL_ID", "AfroLogicInsect/emotionClassifier")

app = FastAPI(title="Emotion Classifier", version="1.0.0")

# Constructed once at import time; the same instance serves both the REST
# endpoint and the Gradio UI callback.
predictor = EmotionPredictor(MODEL_ID)
class PredictRequest(BaseModel):
    """Input payload for the ``predict`` endpoint."""

    # Raw text to classify; passed to the predictor unchanged (no emptiness check here).
    text: str
class PredictResponse(BaseModel):
    """Classification result returned by the ``predict`` endpoint.

    Built directly from the dict returned by ``EmotionPredictor.predict``.
    """

    # Predicted emotion label.
    emotion: str
    # Score associated with ``emotion`` (the UI renders it as a percentage).
    confidence: float
    # Mapping of emotion label -> score; presumably covers every label the model knows.
    all_emotions: dict[str, float]
# --- REST endpoints ---
@app.post("/predict", response_model=PredictResponse)
async def predict(request: PredictRequest) -> PredictResponse:
    """Classify ``request.text`` and return the predicted emotion with scores.

    Args:
        request: JSON body carrying the text to classify.

    Returns:
        A ``PredictResponse`` built from the predictor's result dict.
    """
    # ``predictor.predict`` is a synchronous call. Running it in FastAPI's
    # threadpool keeps the event loop free to serve other requests, including
    # the Gradio UI mounted on the same app.
    result = await run_in_threadpool(predictor.predict, request.text)
    return PredictResponse(**result)
@app.get("/health")
async def health() -> dict[str, str]:
    """Return a static status payload; presumably used as a liveness check.

    Returns:
        ``{"status": "ok"}`` whenever the process is serving requests.
    """
    return {"status": "ok"}
# --- Gradio UI ---
def classify(text: str):
    """Run the predictor on *text* and format the result for the Gradio outputs.

    Returns a ``(headline, scores)`` pair: a human-readable string for the
    "Prediction" textbox and the per-emotion score mapping for ``gr.Label``.
    Blank or whitespace-only input short-circuits with a prompt and an empty
    mapping, without calling the model.
    """
    if text.strip():
        outcome = predictor.predict(text)
        top_label = outcome["emotion"].upper()
        headline = f"{top_label} ({outcome['confidence']:.1%} confidence)"
        return headline, outcome["all_emotions"]
    return "Please enter some text.", {}
# Static page text and sample inputs for the Gradio interface.
_HEADER_MARKDOWN = """
# Emotion Classifier
Fine-tuned DistilBERT that detects six emotions: anger, fear, joy, love, sadness, surprise.
"""

_FOOTER_MARKDOWN = """
---
**API endpoint** also available at `/predict` — see `/docs` for the OpenAPI spec.
"""

# One inner list per example because ``gr.Examples`` expects one value per input component.
_EXAMPLE_TEXTS = [
    ["I am feeling very happy today!"],
    ["This is so frustrating, nothing works."],
    ["I miss my family so much."],
    ["Wow, I did not expect that at all!"],
]

with gr.Blocks(title="Emotion Classifier") as demo:
    gr.Markdown(_HEADER_MARKDOWN)

    # Two-column layout: input on the left, results on the right.
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Input text",
                placeholder="Type something…",
                lines=3,
            )
            submit_btn = gr.Button("Classify", variant="primary")
        with gr.Column():
            label_output = gr.Textbox(label="Prediction")
            scores_output = gr.Label(label="All emotion scores", num_top_classes=6)

    # ``classify`` returns (headline, scores), matching the two outputs in order.
    submit_btn.click(
        fn=classify,
        inputs=text_input,
        outputs=[label_output, scores_output],
    )

    # Clicking an example only fills the input box; no ``fn`` is attached.
    gr.Examples(examples=_EXAMPLE_TEXTS, inputs=text_input)

    gr.Markdown(_FOOTER_MARKDOWN)
# Serve the Gradio UI from the FastAPI app at "/". Starlette matches routes in
# registration order, so any route added to ``app`` before this mount (e.g.
# the REST API at "/predict") is matched before the catch-all UI mount.
app = gr.mount_gradio_app(app=app, blocks=demo, path="/")