from __future__ import annotations

import logging
import os
from pathlib import Path

import gradio as gr
from huggingface_hub import snapshot_download

from src.sentence_classifier.predict import DEFAULT_ARTIFACT_DIR, SentenceFunctionClassifier
|
|
|
|
def resolve_artifact_dir() -> Path:
    """Return the directory to load classifier artifacts from.

    Resolution order:

    1. ``MODEL_DIR`` (or ``DEFAULT_ARTIFACT_DIR`` when ``MODEL_DIR`` is unset
       or empty), if it already contains ``classifier.joblib``.
    2. A snapshot of the Hugging Face Hub repo named by ``MODEL_REPO_ID``,
       downloading only ``*.joblib`` files and ``metadata.json``.
    3. Otherwise the local directory from step 1. It is returned even though
       ``classifier.joblib`` is missing, so the subsequent load attempt is
       what reports the failure.
    """
    # Use `or` rather than a `.get()` default so an empty MODEL_DIR falls back
    # to the default directory instead of silently becoming Path("") == ".".
    # This mirrors how an empty MODEL_REPO_ID is treated as unset below.
    local_dir = Path(os.environ.get("MODEL_DIR") or DEFAULT_ARTIFACT_DIR)
    if (local_dir / "classifier.joblib").exists():
        return local_dir

    model_repo_id = os.environ.get("MODEL_REPO_ID")
    if model_repo_id:
        # snapshot_download returns the local path of the downloaded snapshot.
        return Path(
            snapshot_download(
                repo_id=model_repo_id,
                allow_patterns=["*.joblib", "metadata.json"],
            )
        )

    return local_dir
|
|
|
|
# Load the classifier once, at import time. A failure must not prevent the UI
# from starting: `classifier` stays None and a description of the failure is
# kept in `load_error` so `classify` can show it to the user.
classifier: SentenceFunctionClassifier | None = None
load_error: str | None = None

try:
    classifier = SentenceFunctionClassifier(resolve_artifact_dir())
except Exception as exc:  # broad on purpose: top-level startup boundary
    # Keep the full traceback in the server logs; the UI only gets a summary.
    logging.getLogger(__name__).exception("Failed to load sentence classifier")
    # Prefix the exception type: str(exc) alone is "" for exceptions raised
    # without a message, and an empty load_error would hide the failure.
    load_error = type(exc).__name__ + (f": {exc}" if str(exc) else "")
|
|
|
|
def classify(sentence: str) -> tuple[str, dict[str, float]]:
    """Classify one sentence and format the result for the two UI outputs.

    Args:
        sentence: Text entered in the "Sentence" textbox.

    Returns:
        A ``(summary, probabilities)`` pair. On success, ``summary`` is the
        predicted label (underscores replaced by spaces, title-cased) followed
        by the confidence as a percentage, and ``probabilities`` is
        ``prediction.probabilities``. If no model is loaded, or
        ``classifier.predict`` raises ``ValueError``, ``summary`` holds an
        explanatory message and ``probabilities`` is an empty dict.
    """
    empty_scores: dict[str, float] = {}

    if classifier is not None:
        try:
            prediction = classifier.predict(sentence)
        except ValueError as err:
            # Input rejected by the classifier: show its message verbatim.
            return str(err), empty_scores
        pretty_label = prediction.label.replace("_", " ").title()
        return (
            f"{pretty_label} ({prediction.confidence:.1%} confidence)",
            prediction.probabilities,
        )

    # No model loaded: explain how to provide one and append the load error
    # if one was recorded.
    hint = (
        "Model artifacts are not available yet. Train locally with "
        "`python -m src.sentence_classifier.train`, or set MODEL_REPO_ID "
        "in the Hugging Face Space."
    )
    suffix = f"\n\nLoad error: {load_error}" if load_error else ""
    return hint + suffix, empty_scores
|
|
|
|
# Clickable sample inputs. They appear to be one sentence per class, in the
# same order as the classes named in the description below.
examples = [
    "The lecture starts at nine.",  # declarative
    "Please open the window.",  # imperative
    "Where did you put the keys?",  # interrogative
    "What a remarkable answer!",  # exclamatory
    "May your work be rewarded.",  # optative
]

# One textbox in. Two outputs, in the order of classify's return tuple:
# a textbox for the summary and a label for the per-class probabilities.
demo = gr.Interface(
    classify,
    gr.Textbox(label="Sentence", placeholder="Enter one sentence..."),
    [
        gr.Textbox(label="Prediction"),
        gr.Label(label="Class probabilities"),
    ],
    examples=examples,
    title="Sentence Function Classifier",
    description=(
        "Classifies sentences as declarative, imperative, "
        "interrogative, exclamatory, or optative."
    ),
)

if __name__ == "__main__":
    # Start the web server only when this file is executed as a script.
    demo.launch()
|
|