"""Gradio app that serves the sentence-function classifier.

Model artifacts are loaded once, at import time. The directory named by the
``MODEL_DIR`` environment variable (or the package default) is used when it
contains ``classifier.joblib``. Otherwise, if ``MODEL_REPO_ID`` is set, the
artifacts are fetched from that Hugging Face Hub repository. A load failure
does not crash the app: the error is logged and shown to the user instead.
"""

from __future__ import annotations

import logging
import os
from pathlib import Path

import gradio as gr
from huggingface_hub import snapshot_download

from src.sentence_classifier.predict import DEFAULT_ARTIFACT_DIR, SentenceFunctionClassifier

logger = logging.getLogger(__name__)


def resolve_artifact_dir() -> Path:
    """Return the directory the classifier artifacts should be loaded from.

    Resolution order:
      1. ``MODEL_DIR`` (or ``DEFAULT_ARTIFACT_DIR`` when unset/empty), if it
         already contains ``classifier.joblib``.
      2. A snapshot of the Hub repo named by ``MODEL_REPO_ID``, restricted to
         ``*.joblib`` files and ``metadata.json``.
      3. The local directory from step 1 regardless, so the loader can report
         what is missing.
    """
    # `or` (rather than a .get default) also covers MODEL_DIR="" which would
    # otherwise silently resolve to the current working directory.
    local_dir = Path(os.environ.get("MODEL_DIR") or DEFAULT_ARTIFACT_DIR)
    if (local_dir / "classifier.joblib").exists():
        return local_dir

    model_repo_id = os.environ.get("MODEL_REPO_ID")
    if model_repo_id:
        # snapshot_download returns the local path of the downloaded snapshot.
        return Path(
            snapshot_download(
                repo_id=model_repo_id,
                allow_patterns=["*.joblib", "metadata.json"],
            )
        )
    return local_dir


classifier: SentenceFunctionClassifier | None = None
load_error: str | None = None
try:
    classifier = SentenceFunctionClassifier(resolve_artifact_dir())
except Exception as exc:  # noqa: BLE001 - deliberate: keep the UI up and report the failure
    # Keep the full traceback in the server/Space logs; the UI only gets a summary.
    logger.exception("Failed to load sentence classifier artifacts")
    # Include the type name: str(exc) alone is empty or ambiguous for many
    # exceptions (e.g. KeyError('x') -> "'x'").
    load_error = f"{type(exc).__name__}: {exc}"


def classify(sentence: str) -> tuple[str, dict[str, float]]:
    """Classify one sentence for the Gradio interface.

    Args:
        sentence: Text from the input textbox (may be empty or None).

    Returns:
        A ``(summary, probabilities)`` pair. ``summary`` is a human-readable
        prediction or an error/help message. ``probabilities`` maps class
        labels to scores for ``gr.Label``, and is empty when no prediction was
        made.
    """
    if classifier is None:
        message = (
            "Model artifacts are not available yet. Train locally with "
            "`python -m src.sentence_classifier.train`, or set MODEL_REPO_ID "
            "in the Hugging Face Space."
        )
        if load_error:
            message = f"{message}\n\nLoad error: {load_error}"
        return message, {}

    # Guard blank input here so it never reaches the model: predict() is only
    # known to signal bad input via ValueError, and None/"" may fail differently.
    if not sentence or not sentence.strip():
        return "Please enter a sentence to classify.", {}

    try:
        prediction = classifier.predict(sentence)
    except ValueError as exc:
        return str(exc), {}

    label = prediction.label.replace("_", " ").title()
    summary = f"{label} ({prediction.confidence:.1%} confidence)"
    return summary, prediction.probabilities


examples = [
    "The lecture starts at nine.",
    "Please open the window.",
    "Where did you put the keys?",
    "What a remarkable answer!",
    "May your work be rewarded.",
]

demo = gr.Interface(
    fn=classify,
    inputs=gr.Textbox(label="Sentence", placeholder="Enter one sentence..."),
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Label(label="Class probabilities"),
    ],
    examples=examples,
    title="Sentence Function Classifier",
    description=(
        "Classifies sentences as declarative, imperative, interrogative, "
        "exclamatory, or optative."
    ),
)

if __name__ == "__main__":
    demo.launch()