itsjhuang commited on
Commit
eb4875c
·
verified ·
1 Parent(s): 19eefec

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +56 -0
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio demo for the Watsonx Docs Type Classifier.
3
+ Loads the best trained model from models/ and serves predictions.
4
+ """
5
+
6
+ from pathlib import Path
7
+
8
+ import gradio as gr
9
+ import joblib
10
+ import numpy as np
11
+ from sentence_transformers import SentenceTransformer
12
+
13
+ LABELS = ["conceptual", "how-to"]
14
+
15
+ model_name = (Path("models") / "best_model_name.txt").read_text().strip()
16
+ embedder = SentenceTransformer(model_name)
17
+ clf = joblib.load(Path("models") / "best_model.joblib")
18
+
19
+
20
+ def softmax(x):
21
+ e = np.exp(x - np.max(x))
22
+ return e / e.sum()
23
+
24
+
25
+ def predict(text: str) -> dict:
26
+ if not text.strip():
27
+ return {label: 0.0 for label in LABELS}
28
+ embedding = embedder.encode([text], convert_to_numpy=True)
29
+ if hasattr(clf, "predict_proba"):
30
+ probs = clf.predict_proba(embedding)[0]
31
+ else:
32
+ scores = clf.decision_function(embedding)[0]
33
+ # LinearSVC returns a scalar for binary; wrap in array
34
+ if np.ndim(scores) == 0:
35
+ scores = np.array([-scores, scores])
36
+ probs = softmax(scores)
37
+ return {label: float(p) for label, p in zip(LABELS, probs)}
38
+
39
+
40
+ demo = gr.Interface(
41
+ fn=predict,
42
+ inputs=gr.Textbox(
43
+ label="Document text (title + body)",
44
+ lines=8,
45
+ placeholder="Paste the title and opening text of a Watsonx documentation page here.",
46
+ ),
47
+ outputs=gr.Label(num_top_classes=2, label="Predicted document type"),
48
+ title="Watsonx Docs Type Classifier",
49
+ description=(
50
+ "Predicts whether a Watsonx documentation page is **conceptual** or **how-to**. "
51
+ "Paste the page title and opening text below."
52
+ ),
53
+ )
54
+
55
+ if __name__ == "__main__":
56
+ demo.launch(share=False)