"""Streamlit demo app for the arXiv topic classifier.

Takes a paper title and optional abstract, runs a fine-tuned transformer, and
reports the smallest set of categories whose cumulative probability reaches
the configured coverage threshold.
"""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any

import streamlit as st

from inference import ArticleClassifier, ClassifierError
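# The classifier interface assumed by this app is sketched here for
# orientation. These signatures are hypothetical, inferred from the call sites
# in this file; the real definitions live in inference.py and take precedence:
#
#   class ClassifierError(Exception): ...
#
#   class ArticleClassifier:
#       labels: list[str]  # human-readable category names
#
#       def __init__(self, model_dir: Path, labels_path: Path, max_length: int) -> None: ...
#
#       def predict(self, title: str, abstract: str) -> list[dict[str, float | str]]:
#           ...  # every label with its probability, assumed sorted descending
#
#       def select_top_k_by_probability_mass(
#           self, predictions: list[dict[str, float | str]], threshold: float
#       ) -> list[dict[str, float | str]]:
#           ...  # smallest prefix whose probabilities sum to >= threshold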
PROJECT_DIR = Path(__file__).resolve().parent
CONFIG_PATH = PROJECT_DIR / "configs" / "app_config.json"
METRICS_PATH = PROJECT_DIR / "artifacts" / "large_model" / "metrics.json"

DEFAULT_APP_CONFIG = {
    "model_dir": "artifacts/large_model/best_model",
    "labels_path": "data/processed_large/label_mapping.json",
    "max_length": 256,
    "coverage_threshold": 0.95,
    "model_name": "distilbert-base-uncased",
    "page_title": "arXiv Topic Classifier",
    "page_icon": "📚",
    "example_title": "Learning-based Visual Navigation for Mobile Robots",
    "example_abstract": (
        "We present a transformer-based navigation system that uses camera observations "
        "and scene understanding to plan robust trajectories for indoor mobile robots."
    ),
}


def load_app_config() -> dict[str, Any]:
    """Load configs/app_config.json, merging it over the defaults above."""
    if not CONFIG_PATH.exists():
        return DEFAULT_APP_CONFIG.copy()
    with CONFIG_PATH.open("r", encoding="utf-8") as fh:
        config = json.load(fh)
    merged_config = DEFAULT_APP_CONFIG.copy()
    merged_config.update(config)
    return merged_config


APP_CONFIG = load_app_config()
MODEL_DIR = PROJECT_DIR / str(APP_CONFIG["model_dir"])
LABELS_PATH = PROJECT_DIR / str(APP_CONFIG["labels_path"])
MAX_LENGTH = int(APP_CONFIG["max_length"])
COVERAGE_THRESHOLD = float(APP_CONFIG["coverage_threshold"])

st.set_page_config(
    page_title=str(APP_CONFIG["page_title"]),
    page_icon=str(APP_CONFIG["page_icon"]),
    layout="centered",
)


@st.cache_resource
def load_classifier() -> ArticleClassifier:
    """Build the classifier once and reuse it across Streamlit reruns."""
    return ArticleClassifier(
        model_dir=MODEL_DIR,
        labels_path=LABELS_PATH,
        max_length=MAX_LENGTH,
    )


@st.cache_data
def load_metrics() -> dict[str, Any] | None:
    """Read cached evaluation metrics, if the training run produced them.

    The sidebar reads keys of the form
    {"validation": {"eval_accuracy": ..., "eval_macro_f1": ...},
     "test": {"test_accuracy": ..., "test_macro_f1": ...}}.
    """
    if not METRICS_PATH.exists():
        return None
    with METRICS_PATH.open("r", encoding="utf-8") as fh:
        return json.load(fh)


def format_probability(probability: float) -> str:
    return f"{probability * 100:.2f}%"


def format_threshold(threshold: float) -> str:
    return f"{threshold * 100:.0f}%"


def render_prediction_rows(predictions: list[dict[str, float | str]]) -> None:
    for index, item in enumerate(predictions, start=1):
        label = str(item["label"])
        probability = float(item["probability"])
        st.write(f"{index}. `{label}`")
        # Clamp to [0, 1] so st.progress never raises on float drift.
        st.progress(min(max(probability, 0.0), 1.0), text=format_probability(probability))


def main() -> None:
    coverage_label = format_threshold(COVERAGE_THRESHOLD)

    st.title(str(APP_CONFIG["page_title"]))
    st.write(
        "This demo predicts arXiv paper topics from the title and abstract "
        "using a transformer classifier."
    )
    st.caption(
        "For homework evaluation, the app returns the smallest prefix of categories whose "
        f"cumulative probability reaches {coverage_label}."
    )
    st.info(
        "How to test: paste a real or synthetic paper title, optionally add an abstract, and press "
        "`Predict categories`. If the abstract is empty, the model classifies from the title only."
    )

    classifier: ArticleClassifier | None = None
    classifier_load_error: str | None = None

    with st.sidebar:
        try:
            classifier = load_classifier()
        except Exception as exc:  # broad on purpose: degrade to an in-app error message
            classifier_load_error = f"Model initialization error in load_classifier: {exc}"
        metrics = load_metrics()

        st.subheader("Evaluation Summary")
        st.write(f"Model: `{APP_CONFIG['model_name']}`")
        if classifier is not None:
            st.write(f"Number of classes: `{len(classifier.labels)}`")
            st.write("Classes: " + ", ".join(f"`{label}`" for label in classifier.labels))
        else:
            st.error(classifier_load_error or "Model initialization error: unknown error.")

        if metrics is not None:
            validation_accuracy = metrics.get("validation", {}).get("eval_accuracy")
            validation_f1 = metrics.get("validation", {}).get("eval_macro_f1")
            test_accuracy = metrics.get("test", {}).get("test_accuracy")
            test_f1 = metrics.get("test", {}).get("test_macro_f1")
            if validation_accuracy is not None:
                st.write(f"Validation accuracy: `{validation_accuracy:.4f}`")
            if validation_f1 is not None:
                st.write(f"Validation macro-F1: `{validation_f1:.4f}`")
            if test_accuracy is not None:
                st.write(f"Test accuracy: `{test_accuracy:.4f}`")
            if test_f1 is not None:
                st.write(f"Test macro-F1: `{test_f1:.4f}`")

        st.write(
            "Output rule: return categories until cumulative probability reaches "
            f"{coverage_label}."
        )

    with st.expander("Example Input For Quick Check"):
        st.markdown(
            f"**Title:** {APP_CONFIG['example_title']}\n\n"
            f"**Abstract:** {APP_CONFIG['example_abstract']}"
        )

    with st.form("prediction_form"):
        title = st.text_input(
            "Article title",
            placeholder="Enter the article title",
        )
        abstract = st.text_area(
            "Abstract",
            placeholder="Enter the abstract (optional, but recommended)",
            height=220,
        )
        predict_button = st.form_submit_button("Predict categories", type="primary")

    if predict_button:
        if classifier is None:
            st.error(classifier_load_error or "Model initialization error: classifier is unavailable.")
            return
        if not title.strip() and not abstract.strip():
            st.error("Input validation error in app: please enter at least a title or an abstract.")
            return

        with st.spinner("Running inference..."):
            try:
                full_predictions = classifier.predict(title=title, abstract=abstract)
                predictions = classifier.select_top_k_by_probability_mass(
                    full_predictions,
                    threshold=COVERAGE_THRESHOLD,
                )
            except ValueError as exc:
                st.error(str(exc))
                return
            except ClassifierError as exc:
                st.error(f"Classifier error in prediction flow: {exc}")
                return
            except Exception as exc:
                st.error(f"Unexpected inference error in app.main: {exc}")
                return

        best_prediction = predictions[0]
        covered_probability = sum(float(item["probability"]) for item in predictions)

        col1, col2, col3 = st.columns(3)
        col1.metric("Top class", str(best_prediction["label"]))
        col2.metric("Top probability", format_probability(float(best_prediction["probability"])))
        # Label tracks the configured threshold instead of hard-coding "95%".
        col3.metric(f"Top-{coverage_label} coverage", format_probability(covered_probability))

        st.subheader("Top categories")
        st.caption(
            f"These are the categories returned by the assignment top-{coverage_label} rule."
        )
        render_prediction_rows(predictions)

        with st.expander("Show Full Ranking"):
            render_prediction_rows(full_predictions)


if __name__ == "__main__":
    main()
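# Usage note (a sketch, assuming this module is saved as app.py in the project
# root next to inference.py):
#
#   streamlit run app.py
#
# Streamlit re-executes the whole script on every widget interaction; the
# @st.cache_resource / @st.cache_data decorators above keep the model and the
# metrics file loaded across those reruns instead of reloading them each time.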