# Uploaded by MarcusBennevall via huggingface_hub ("Upload app.py with huggingface_hub"),
# commit eb947f7 (verified).
from __future__ import annotations

import logging
import os
from pathlib import Path

import gradio as gr
from huggingface_hub import snapshot_download

from src.sentence_classifier.predict import DEFAULT_ARTIFACT_DIR, SentenceFunctionClassifier
def resolve_artifact_dir() -> Path:
    """Return the directory that should hold the classifier artifacts.

    The local directory comes from the ``MODEL_DIR`` environment variable,
    defaulting to ``DEFAULT_ARTIFACT_DIR``. It is used as-is when it already
    contains ``classifier.joblib``.

    Otherwise, if ``MODEL_REPO_ID`` is set to a non-empty value, a snapshot of
    that Hugging Face Hub repository is fetched. Only ``*.joblib`` files and
    ``metadata.json`` are fetched, and the snapshot's local path is returned.

    If neither applies, the local directory is returned anyway, and the caller
    is left to deal with the missing artifact.
    """
    candidate = Path(os.environ.get("MODEL_DIR", DEFAULT_ARTIFACT_DIR))
    if not (candidate / "classifier.joblib").exists():
        repo_id = os.environ.get("MODEL_REPO_ID")
        # An unset or empty MODEL_REPO_ID means "no remote fallback".
        if repo_id:
            snapshot_path = snapshot_download(
                repo_id=repo_id,
                allow_patterns=["*.joblib", "metadata.json"],
            )
            return Path(snapshot_path)
    return candidate
# Load the classifier once, at import time. A failure is recorded instead of
# raised, so the interface can still start and ``classify`` can report the
# problem to the user.
classifier: SentenceFunctionClassifier | None = None
load_error: str | None = None
try:
    classifier = SentenceFunctionClassifier(resolve_artifact_dir())
except Exception as exc:  # top-level boundary: record and log, never crash the app
    # Some exceptions stringify to "" (e.g. a bare ``KeyError()``), which would
    # make ``load_error`` falsy and hide the failure. Others, like
    # ``KeyError('x')``, stringify to an uninformative fragment. Prefix the type
    # name so the message is always present and meaningful.
    load_error = type(exc).__name__
    if str(exc):
        load_error = f"{load_error}: {exc}"
    # Keep the full traceback in the server logs; the UI only shows the summary.
    logging.getLogger(__name__).exception("Failed to load the sentence classifier")
def classify(sentence: str) -> tuple[str, dict[str, float]]:
    """Classify a single sentence for the Gradio interface.

    Returns a ``(summary, probabilities)`` pair:

    - ``summary`` has the form ``"<Label> (<confidence> confidence)"``. The
      label's underscores become spaces, and the label is title-cased.
    - ``probabilities`` is the mapping taken from the prediction, which the
      ``gr.Label`` output displays.

    If the model was not loaded, the summary explains how to provide
    artifacts, with the recorded load error appended when one exists. If
    ``predict`` raises ``ValueError``, its message is returned. In both cases
    the mapping is empty.
    """
    if classifier is not None:
        # Only ``predict`` is guarded: formatting errors are not input errors.
        try:
            result = classifier.predict(sentence)
        except ValueError as err:
            return str(err), {}
        pretty_label = result.label.replace("_", " ").title()
        return (
            f"{pretty_label} ({result.confidence:.1%} confidence)",
            result.probabilities,
        )

    parts = [
        "Model artifacts are not available yet. Train locally with "
        "`python -m src.sentence_classifier.train`, or set MODEL_REPO_ID "
        "in the Hugging Face Space."
    ]
    if load_error:
        parts.append(f"Load error: {load_error}")
    return "\n\n".join(parts), {}
# One sample sentence per sentence function named in the interface
# description. The keys only label the samples; the interface receives the
# values, in insertion order.
_EXAMPLES_BY_FUNCTION = {
    "declarative": "The lecture starts at nine.",
    "imperative": "Please open the window.",
    "interrogative": "Where did you put the keys?",
    "exclamatory": "What a remarkable answer!",
    "optative": "May your work be rewarded.",
}
examples = list(_EXAMPLES_BY_FUNCTION.values())
# The interface components are created first and then wired together.
# Input: a single sentence box. Outputs: the formatted verdict and the
# per-class probabilities, both produced by ``classify``.
_sentence_box = gr.Textbox(label="Sentence", placeholder="Enter one sentence...")
_prediction_box = gr.Textbox(label="Prediction")
_probability_label = gr.Label(label="Class probabilities")
_DESCRIPTION = (
    "Classifies sentences as declarative, imperative, interrogative, "
    "exclamatory, or optative."
)

demo = gr.Interface(
    fn=classify,
    inputs=_sentence_box,
    outputs=[_prediction_box, _probability_label],
    examples=examples,
    title="Sentence Function Classifier",
    description=_DESCRIPTION,
)

if __name__ == "__main__":
    # Start the server only when executed as a script; importing this module
    # merely builds ``demo``.
    demo.launch()