File size: 7,325 Bytes
70b2ea0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
from __future__ import annotations

import json
from pathlib import Path
from typing import Any

import streamlit as st

from inference import ArticleClassifier, ClassifierError


# Every project-relative path is anchored at the directory containing this file.
PROJECT_DIR = Path(__file__).resolve().parent
CONFIG_PATH = PROJECT_DIR.joinpath("configs", "app_config.json")
METRICS_PATH = PROJECT_DIR.joinpath("artifacts", "large_model", "metrics.json")
# Built-in settings; load_app_config() overlays any keys found in CONFIG_PATH.
DEFAULT_APP_CONFIG = dict(
    model_dir="artifacts/large_model/best_model",
    labels_path="data/processed_large/label_mapping.json",
    max_length=256,
    coverage_threshold=0.95,
    model_name="distilbert-base-uncased",
    page_title="arXiv Topic Classifier",
    page_icon="📚",
    example_title="Learning-based Visual Navigation for Mobile Robots",
    example_abstract=(
        "We present a transformer-based navigation system that uses camera observations "
        "and scene understanding to plan robust trajectories for indoor mobile robots."
    ),
)


def load_app_config() -> dict[str, Any]:
    """Return the app settings: ``DEFAULT_APP_CONFIG`` overlaid with ``CONFIG_PATH``.

    When the config file does not exist, a copy of the defaults is returned.
    Otherwise every top-level key in the file overrides the default of the same
    name, and keys absent from the file keep their default values. The returned
    dict is always a fresh copy, so callers may mutate it freely.

    Raises:
        json.JSONDecodeError: if the config file is not valid JSON.
        ValueError: if the config file's top-level value is not a JSON object.
    """
    merged_config = DEFAULT_APP_CONFIG.copy()
    if not CONFIG_PATH.exists():
        return merged_config

    with CONFIG_PATH.open("r", encoding="utf-8") as fh:
        config = json.load(fh)

    # dict.update() would silently accept a JSON list of [key, value] pairs and
    # fail with a confusing message on other non-object values, so reject them
    # explicitly with an actionable error.
    if not isinstance(config, dict):
        raise ValueError(
            f"App config {CONFIG_PATH} must contain a JSON object, "
            f"got {type(config).__name__}."
        )

    merged_config.update(config)
    return merged_config


APP_CONFIG = load_app_config()
# Config paths are joined onto PROJECT_DIR; per pathlib joining rules an
# absolute path in the config replaces the project prefix entirely.
MODEL_DIR = PROJECT_DIR.joinpath(str(APP_CONFIG["model_dir"]))
LABELS_PATH = PROJECT_DIR.joinpath(str(APP_CONFIG["labels_path"]))
MAX_LENGTH = int(APP_CONFIG["max_length"])
COVERAGE_THRESHOLD = float(APP_CONFIG["coverage_threshold"])


# Page metadata comes straight from the merged app config.
st.set_page_config(
    layout="centered",
    page_icon=str(APP_CONFIG["page_icon"]),
    page_title=str(APP_CONFIG["page_title"]),
)


@st.cache_resource
def load_classifier() -> ArticleClassifier:
    """Construct the ``ArticleClassifier`` from the module-level model settings.

    Decorated with ``st.cache_resource`` so Streamlit shares one instance across
    reruns instead of reloading the model on every interaction.
    """
    settings = {
        "model_dir": MODEL_DIR,
        "labels_path": LABELS_PATH,
        "max_length": MAX_LENGTH,
    }
    return ArticleClassifier(**settings)


@st.cache_data
def load_metrics() -> dict[str, Any] | None:
    """Read the saved evaluation metrics from ``METRICS_PATH``.

    The metrics are optional sidebar information. A missing, unreadable or
    malformed file, or one whose top level is not a JSON object, is therefore
    treated as "no metrics" instead of crashing the whole page.

    Returns:
        The parsed JSON object, or ``None`` when no usable metrics are available.
    """
    if not METRICS_PATH.exists():
        return None

    try:
        with METRICS_PATH.open("r", encoding="utf-8") as fh:
            metrics = json.load(fh)
    except (OSError, ValueError):
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        return None

    # Callers use dict access (.get) on the result, so only an object is usable.
    return metrics if isinstance(metrics, dict) else None


def format_probability(probability: float) -> str:
    """Render a fractional probability as a percentage with two decimals (0.5 -> "50.00%")."""
    percent = probability * 100
    return "{:.2f}%".format(percent)


def format_threshold(threshold: float) -> str:
    """Render a fractional threshold as a whole-number percentage (0.95 -> "95%")."""
    return "%.0f%%" % (threshold * 100)


def render_prediction_rows(predictions: list[dict[str, float | str]]) -> None:
    """Write one numbered row per prediction: the label, then a progress bar.

    Each item must provide ``"label"`` and ``"probability"`` keys. The bar value
    is clamped to [0.0, 1.0] before it is passed to ``st.progress``. The bar's
    caption shows the unclamped probability as a percentage.
    """
    for position, prediction in enumerate(predictions, 1):
        label_text = str(prediction["label"])
        prob = float(prediction["probability"])
        st.write(f"{position}. `{label_text}`")
        bar_value = min(max(prob, 0.0), 1.0)
        st.progress(bar_value, text=format_probability(prob))


def _render_metrics_summary(metrics: dict[str, Any]) -> None:
    """Write whichever validation/test accuracy and macro-F1 values *metrics* holds."""
    validation_section = metrics.get("validation")
    test_section = metrics.get("test")
    # Tolerate sections that are missing or not JSON objects instead of crashing.
    if not isinstance(validation_section, dict):
        validation_section = {}
    if not isinstance(test_section, dict):
        test_section = {}

    rows = (
        ("Validation accuracy", validation_section.get("eval_accuracy")),
        ("Validation macro-F1", validation_section.get("eval_macro_f1")),
        ("Test accuracy", test_section.get("test_accuracy")),
        ("Test macro-F1", test_section.get("test_macro_f1")),
    )
    for caption, value in rows:
        # Only real numbers can be formatted with ``:.4f``; skip anything else.
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            st.write(f"{caption}: `{value:.4f}`")


def main() -> None:
    """Render the demo page: intro text, sidebar summary, input form, and results.

    Model-loading and inference failures are reported on the page via
    ``st.error`` rather than propagated, so the rest of the page still renders.
    Predictions are shown as the prefix of categories selected by
    ``select_top_k_by_probability_mass`` at ``COVERAGE_THRESHOLD``, followed by
    the full ranking in an expander.
    """
    coverage_label = format_threshold(COVERAGE_THRESHOLD)

    st.title(str(APP_CONFIG["page_title"]))
    st.write(
        "This demo predicts arXiv paper topics from the title and abstract using a transformer classifier."
    )
    st.caption(
        "For homework evaluation, the app returns the smallest prefix of categories whose cumulative "
        f"probability reaches {coverage_label}."
    )
    st.info(
        "How to test: paste a real or synthetic paper title, optionally add an abstract, and press "
        "`Predict categories`. If the abstract is empty, the model will classify from the title only."
    )

    classifier: ArticleClassifier | None = None
    classifier_load_error: str | None = None

    with st.sidebar:
        # A loading failure is recorded, not raised: it is shown in the sidebar
        # below and again if the user tries to predict.
        try:
            classifier = load_classifier()
        except Exception as exc:
            classifier_load_error = f"Model initialization error in load_classifier: {exc}"

        metrics = load_metrics()
        st.subheader("Evaluation Summary")
        st.write(f"Model: `{APP_CONFIG['model_name']}`")
        if classifier is not None:
            st.write(f"Number of classes: `{len(classifier.labels)}`")
            st.write("Classes: " + ", ".join(f"`{label}`" for label in classifier.labels))
        else:
            st.error(classifier_load_error or "Model initialization error: unknown error")
        # The metrics file is free-form JSON; only an object can be summarized.
        if isinstance(metrics, dict):
            _render_metrics_summary(metrics)
        st.write(
            "Output rule: return categories until cumulative probability reaches "
            f"{coverage_label}"
        )

    with st.expander("Example Input For Quick Check"):
        st.markdown(
            f"**Title:** {APP_CONFIG['example_title']}\n\n"
            f"**Abstract:** {APP_CONFIG['example_abstract']}"
        )

    with st.form("prediction_form"):
        title = st.text_input(
            "Article title",
            placeholder="Enter the article title",
        )
        abstract = st.text_area(
            "Abstract",
            placeholder="Enter the abstract (optional, but recommended)",
            height=220,
        )
        predict_button = st.form_submit_button("Predict categories", type="primary")

    if not predict_button:
        return
    if classifier is None:
        st.error(classifier_load_error or "Model initialization error: classifier is unavailable.")
        return
    if not title.strip() and not abstract.strip():
        st.error("Input validation error in app: please enter at least a title or an abstract.")
        return

    with st.spinner("Running inference..."):
        # ClassifierError is handled before ValueError so its dedicated message
        # is reachable even if it subclasses ValueError (not visible from here).
        try:
            full_predictions = classifier.predict(title=title, abstract=abstract)
            predictions = classifier.select_top_k_by_probability_mass(
                full_predictions,
                threshold=COVERAGE_THRESHOLD,
            )
        except ClassifierError as exc:
            st.error(f"Classifier error in prediction flow: {exc}")
            return
        except ValueError as exc:
            st.error(str(exc))
            return
        except Exception as exc:
            st.error(f"Unexpected inference error in app.main: {exc}")
            return

    # Guard the predictions[0] access below against an empty selection.
    if not predictions:
        st.error("Classifier error in prediction flow: no categories were returned.")
        return

    best_prediction = predictions[0]
    covered_probability = sum(float(item["probability"]) for item in predictions)
    col1, col2, col3 = st.columns(3)
    col1.metric("Top class", str(best_prediction["label"]))
    col2.metric("Top probability", format_probability(float(best_prediction["probability"])))
    # Label follows the configured threshold instead of a hard-coded 95%.
    col3.metric(f"Top-{coverage_label} coverage", format_probability(covered_probability))

    st.subheader("Top categories")
    st.caption(
        f"These are the categories returned by the assignment top-{coverage_label} rule."
    )
    render_prediction_rows(predictions)

    with st.expander("Show Full Ranking"):
        render_prediction_rows(full_predictions)


# Render the page when this file is executed as the main script (as `streamlit run` does).
if __name__ == "__main__":
    main()