File size: 5,626 Bytes
beb558a
 
 
0b24aad
beb558a
 
 
0b24aad
beb558a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318a267
beb558a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee52ba9
beb558a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b24aad
 
beb558a
0b24aad
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import json
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download
from tensorflow import keras
from sentence_transformers import SentenceTransformer

# -----------------------
# Config
# -----------------------
# Hub repository that the classifier weights and label mapping are fetched from.
REPO_ID = "Bocklitz-Lab/lit2vec-subfield-classifier-model"

# Sentence-embedding model used to turn the input text into a vector.
EMBED_MODEL = "intfloat/e5-large-v2"

# Prefix prepended to the input text before encoding, keyed by text type.
TEXT_PREFIX = dict(
    abstract="abstract: ",
    summary="summary: ",
)

# Initial values of the threshold and top-k sliders in the UI.
DEFAULT_THRESHOLD = 0.5
TOPK_DEFAULT = 5

# -----------------------
# Load model + labels at startup
# -----------------------
# Everything in this section runs once at import time, before the UI is built.
#
# Fetch the two artifacts from REPO_ID; the returned values are used below
# as local file paths.
# Keras model (saved as .h5 on the Hub)
MODEL_PATH = hf_hub_download(REPO_ID, filename="mlp_model.h5")
LABEL_MAP_PATH = hf_hub_download(REPO_ID, filename="label_mapping.json")

with open(LABEL_MAP_PATH, "r", encoding="utf-8") as f:
    mapping = json.load(f)
# JSON object keys are always strings; convert them to int so labels can be
# looked up by class index.
INDEX_TO_LABEL = {int(k): v for k, v in mapping["index_to_label"].items()}

# load Keras model for inference
# (compile=False: only prediction is done in this file, no training setup)
MODEL = keras.models.load_model(MODEL_PATH, compile=False)

# SentenceTransformer encoder (CPU-only for portability)
ENCODER = SentenceTransformer(EMBED_MODEL, device="cpu")

def encode_text(text: str, text_type: str = "abstract") -> np.ndarray:
    """Embed one text with the sentence encoder for use by the classifier.

    The prefix registered for ``text_type`` in ``TEXT_PREFIX`` is prepended
    to ``text`` before encoding; an unknown ``text_type`` adds no prefix.

    Args:
        text: The text to embed.
        text_type: Key into ``TEXT_PREFIX`` selecting the prefix.

    Returns:
        The encoder output for a one-item batch (expected shape ``(1, D)``),
        requested with normalized embeddings and cast to ``float32``.
    """
    prefixed = TEXT_PREFIX.get(text_type, "") + text
    vectors = ENCODER.encode([prefixed], normalize_embeddings=True)
    return vectors.astype("float32")

def predict(text: str, text_type: str, threshold: float, topk: int):
    """Classify ``text`` and format the results for the UI.

    Args:
        text: Abstract or summary to classify. ``None`` and whitespace-only
            input are treated as empty.
        text_type: Text type forwarded to ``encode_text`` to pick the prefix.
        threshold: Minimum score (inclusive) for a label to be reported as
            predicted.
        topk: Number of best-scoring labels to list; values below 1 are
            raised to 1.

    Returns:
        A 3-tuple ``(selected, top_k, table)``:

        * ``selected`` -- comma-separated labels whose score is at least
          ``threshold`` (in class-index order), or ``"—"`` if none qualify.
        * ``top_k`` -- one ``"label: score"`` line per top-k label, scores
          formatted with three decimals.
        * ``table`` -- ``[label, score]`` rows for every class, sorted by
          descending score.

        For empty input the result is ``("", "", [])``.
    """
    text = (text or "").strip()
    if not text:
        return ("", "", [])

    features = encode_text(text, text_type=text_type)
    # The batch holds a single text, so keep only its row of class scores.
    probs = MODEL.predict(features, verbose=0)[0]

    # Rank the classes once (descending score); the top-k list and the full
    # table both read from this single ordering instead of sorting twice.
    ranking = np.argsort(-probs)

    # Thresholded predictions
    selected = [INDEX_TO_LABEL[i] for i, p in enumerate(probs) if p >= threshold]
    pred_display = ", ".join(selected) if selected else "—"

    # Top-k predictions (by score); slicing past the end is harmless.
    topk = max(1, int(topk))
    topk_display = "\n".join(
        f"{INDEX_TO_LABEL[i]}: {probs[i]:.3f}" for i in ranking[:topk]
    )

    # Table of all scores; float() turns NumPy scalars into plain floats.
    table = [[INDEX_TO_LABEL[i], float(probs[i])] for i in ranking]

    return pred_display, topk_display, table

# -----------------------
# Gradio UI
# -----------------------
# Sample abstract pre-filled in the input box so the demo can be run as-is.
EXAMPLE_ABSTRACT = "Ultraviolet B (UVB; 290~320nm) irradiation-induced lipid peroxidation induces inflammatory responses that lead to skin wrinkle formation and epidermal thickening. Peroxisome proliferator-activated receptor (PPAR) α/γ dual agonists have the potential to be used as anti-wrinkle agents because they inhibit inflammatory response and lipid peroxidation. In this study, we evaluated the function of 2-bromo-4-(5-chloro-benzo[d]thiazol-2-yl) phenol (MHY 966), a novel synthetic PPAR α/γ dual agonist, and investigated its anti-inflammatory and anti-lipid peroxidation effects. The action of MHY 966 as a PPAR α/γ dual agonist was also determined in vitro by reporter gene assay. Additionally, 8-week-old melanin-possessing hairless mice 2 (HRM2) were exposed to 150 mJ/cm2 UVB every other day for 17 days and MHY 966 was simultaneously pre-treated every day for 17 days to investigate the molecular mechanisms involved. MHY 966 was found to stimulate the transcriptional activities of both PPAR α and γ. In HRM2 mice, we found that the skins of mice exposed to UVB showed significantly increased pro-inflammatory mediator levels (NF-κB, iNOS, and COX-2) and increased lipid peroxidation, whereas MHY 966 co-treatment down-regulated these effects of UVB by activating PPAR α and γ. Thus, the present study shows that MHY 966 exhibits beneficial effects on inflammatory responses and lipid peroxidation by simultaneously activating PPAR α and γ. The major finding of this study is that MHY 966 demonstrates potential as an agent against wrinkle formation associated with chronic UVB exposure."

with gr.Blocks(fill_height=True) as demo:
    # Page header.
    gr.Markdown(
        """
        # Lit2Vec Subfield Classifier
        Enter a **chemistry abstract or summary**. The app encodes it with `e5-large-v2` and predicts one or more **subfields** using the MLP model.

        **Model:** `Bocklitz-Lab/lit2vec-subfield-classifier-model`  
        **Encoder:** `intfloat/e5-large-v2`
        """
    )

    # Controls, side by side: prefix choice, decision threshold, top-k size.
    with gr.Row():
        text_type = gr.Radio(
            choices=["abstract", "summary"],
            value="abstract",
            label="Text type (prefix used for encoding)",
        )
        threshold = gr.Slider(
            0.0,
            1.0,
            value=DEFAULT_THRESHOLD,
            step=0.01,
            label="Decision threshold",
        )
        topk = gr.Slider(
            1,
            10,
            value=TOPK_DEFAULT,
            step=1,
            label="Top-K to display",
        )

    # Text to classify.
    input_box = gr.Textbox(
        label="Paste abstract / summary",
        placeholder="Paste your chemistry abstract here…",
        lines=12,
        value=EXAMPLE_ABSTRACT,
    )

    run_btn = gr.Button("Predict subfield(s)")

    # Text outputs, side by side.
    with gr.Row():
        selected_labels = gr.Textbox(
            label="Predicted fields (thresholded)",
            lines=2,
        )
        topk_labels = gr.Textbox(
            label="Top-K (scores)",
            lines=6,
        )

    # Read-only table with one row per subfield.
    scores_table = gr.Dataframe(
        headers=["Subfield", "Score"],
        datatype=["str", "number"],
        label="All scores (sorted)",
        interactive=False,
    )

    # The button passes the four inputs to predict in this order; its three
    # return values fill the three outputs in this order.
    run_btn.click(
        fn=predict,
        inputs=[input_box, text_type, threshold, topk],
        outputs=[selected_labels, topk_labels, scores_table],
    )

if __name__ == "__main__":
    # Start the app only when this file is run as a script, not on import.
    # On Spaces, Gradio sets host/port; keep defaults.
    demo.launch()