# Author: newreyy
# Commit ca8ae50 (verified): feat(inference): support for multilabel inference
import os
import re
import gradio as gr
import pandas as pd
import tempfile
from collections import OrderedDict
from transformers import (
pipeline,
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
)
from huggingface_hub import list_models
# =====================================================
# CONFIG
# =====================================================
# Upper bound on how many loaded pipelines are kept alive at once; older
# entries are evicted by get_pipeline once this is exceeded.
MAX_CACHE: int = 2
# Default score threshold for multilabel filtering.
DEFAULT_THRESHOLD: float = 0.5
# =====================================================
# PIPELINE CACHE (LRU)
# =====================================================
# Keys are "<model_name>::<mode>"; insertion/access order is used for LRU
# eviction (least recently used entry first).
PIPELINE_CACHE: "OrderedDict[str, object]" = OrderedDict()
def reset_pipeline_cache():
    """Forget every cached pipeline.

    Subsequent ``get_pipeline`` calls will rebuild pipelines from scratch.
    """
    PIPELINE_CACHE.clear()
def get_pipeline(model_name: str, mode: str = "binary"):
    """Return a cached text-classification pipeline for ``model_name``.

    Pipelines are memoised in ``PIPELINE_CACHE`` under the key
    ``"<model_name>::<mode>"``. Once more than ``MAX_CACHE`` entries are
    held, the least recently used one is evicted.

    Args:
        model_name: Hub repo id of a sequence-classification model. Empty
            values and placeholders starting with "(" are rejected.
        mode: ``"binary"`` or ``"multilabel"``. In ``"multilabel"`` mode the
            model config is forced to
            ``problem_type="multi_label_classification"``; otherwise the
            checkpoint's config is used unchanged.

    Returns:
        A ``transformers`` text-classification pipeline.

    Raises:
        ValueError: if ``model_name`` is empty or a placeholder.
        Errors raised by the ``transformers`` loaders propagate unchanged.
    """
    if not model_name or model_name.startswith("("):
        raise ValueError("Invalid model name")

    cache_key = f"{model_name}::{mode}"
    if cache_key in PIPELINE_CACHE:
        PIPELINE_CACHE.move_to_end(cache_key)
        return PIPELINE_CACHE[cache_key]

    # Empty-string env values are normalised to None (anonymous access).
    token = (
        os.getenv("HF_TOKEN")
        or os.getenv("HUGGINGFACEHUB_API_TOKEN")
        or None
    )

    config = AutoConfig.from_pretrained(model_name, token=token)
    if mode == "multilabel":
        # Force multilabel semantics regardless of what the checkpoint declares.
        config.problem_type = "multi_label_classification"

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        config=config,
        token=token,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)

    # Never truncate beyond what the tokenizer says the model supports: some
    # checkpoints accept fewer than 512 tokens and would fail on long inputs.
    # Tokenizers without a real limit report a huge sentinel, so 512 is kept.
    max_length = 512
    model_limit = getattr(tokenizer, "model_max_length", None)
    if isinstance(model_limit, int) and 0 < model_limit < max_length:
        max_length = model_limit

    pipe = pipeline(
        task="text-classification",
        model=model,
        tokenizer=tokenizer,
        truncation=True,
        padding=True,
        max_length=max_length,
        token=token,
    )

    PIPELINE_CACHE[cache_key] = pipe
    PIPELINE_CACHE.move_to_end(cache_key)
    # Evict least recently used pipelines to bound memory usage.
    while len(PIPELINE_CACHE) > MAX_CACHE:
        PIPELINE_CACHE.popitem(last=False)
    return pipe
# =====================================================
# LOAD MODELS
# =====================================================
def _model_id(info):
    """Return a Hub model's repo id, preferring ``id`` over legacy ``modelId``."""
    return getattr(info, "id", None) or getattr(info, "modelId", None)


def _dropdown_updates(choices, first=None, second=None):
    """Build updates for the four model dropdowns (single A/B, batch A/B)."""
    return (
        gr.update(choices=choices, value=first),
        gr.update(choices=choices, value=second),
        gr.update(choices=choices, value=first),
        gr.update(choices=choices, value=second),
    )


def load_user_models(username: str):
    """List the Hub models authored by ``username`` and refresh the dropdowns.

    Args:
        username: Hugging Face account name. ``None`` or blank input is
            treated as missing.

    Returns:
        A 5-tuple: four ``gr.update`` objects (single baseline, single
        candidate, batch baseline, batch candidate) followed by a status
        message. Baselines default to the first sorted id. Candidates default
        to the second id, or the first when only one model exists.
    """
    username = (username or "").strip()
    if not username:
        return (*_dropdown_updates([]), "❌ Username required")

    # Cached pipelines belong to the previous model list; drop them first.
    reset_pipeline_cache()
    model_ids = sorted(
        model_id
        for model_id in map(_model_id, list_models(author=username))
        if model_id
    )
    if not model_ids:
        return (*_dropdown_updates([]), "⚠️ No models found")

    first = model_ids[0]
    second = model_ids[1] if len(model_ids) > 1 else first
    return (
        *_dropdown_updates(model_ids, first, second),
        f"✅ {len(model_ids)} models loaded",
    )
# =====================================================
# TEXT CLEANING
# =====================================================
_URL_PATTERN = re.compile(r"https?://\S+|www\.\S+")
_EMAIL_PATTERN = re.compile(r"\b[\w\.-]+@[\w\.-]+\.\w+\b")
_MENTION_PATTERN = re.compile(r"@\w+")
_WHITESPACE_PATTERN = re.compile(r"\s+")
# Delete hashes and quote characters; turn newlines into spaces.
_CHAR_TABLE = str.maketrans({"#": None, '"': None, "'": None, "\n": " "})


def clean_text(text):
    """Normalise raw text before classification.

    URLs become ``<link>``, e-mail addresses ``<email>`` and ``@mentions``
    ``<user>`` (applied in that order, so e-mails are not mistaken for
    mentions). Hashes and both quote characters are removed, and all
    whitespace runs collapse to single spaces. Non-string input yields "".
    """
    if not isinstance(text, str):
        return ""
    for pattern, placeholder in (
        (_URL_PATTERN, "<link>"),
        (_EMAIL_PATTERN, "<email>"),
        (_MENTION_PATTERN, "<user>"),
    ):
        text = pattern.sub(placeholder, text)
    text = text.translate(_CHAR_TABLE)
    return _WHITESPACE_PATTERN.sub(" ", text).strip()
# =====================================================
# INFERENCE HELPER
# =====================================================
def run_inference(pipe, inputs, mode="binary", batch_size=None):
    """Call ``pipe`` on ``inputs`` with this app's standard options.

    ``top_k=None`` is always passed so scores for every label come back;
    picking the best label (binary) or thresholding (multilabel) is left to
    post-processing. ``"multilabel"`` mode additionally requests a sigmoid
    over the logits. ``batch_size`` is forwarded only when provided.
    """
    options = {"top_k": None}
    if mode == "multilabel":
        options["function_to_apply"] = "sigmoid"
    if batch_size is not None:
        options["batch_size"] = batch_size
    return pipe(inputs, **options)
# =====================================================
# POSTPROCESS
# =====================================================
def postprocess(preds, mode="binary", threshold=0.5):
    """Normalise raw pipeline output into one result per input sample.

    Accepted shapes of ``preds``:
      - single binary      -> dict / list[dict]
      - single multilabel  -> list[dict]
      - batch binary       -> list[list[dict]]
      - batch multilabel   -> list[list[dict]]

    Returns:
        A list with one entry per sample. In ``"binary"`` mode each entry is
        the highest-scoring ``{"label", "score"}`` dict. In any other mode
        each entry is the list of ``{"label", "score"}`` dicts whose score
        is ``>= threshold``. Scores are rounded to 6 decimal places.
    """
    if isinstance(preds, dict):
        samples = [[preds]]
    elif isinstance(preds, list) and preds and isinstance(preds[0], dict):
        samples = [preds]
    else:
        samples = preds

    def _entry(item):
        return {"label": item["label"], "score": round(float(item["score"]), 6)}

    if mode == "binary":
        return [_entry(max(sample, key=lambda x: x["score"])) for sample in samples]
    return [
        [_entry(item) for item in sample if float(item["score"]) >= threshold]
        for sample in samples
    ]
# =====================================================
# SINGLE TEXT
# =====================================================
def _predict_single(cleaned, model_name, mode, threshold):
    """Run one model on already-cleaned text and wrap the result for gr.JSON.

    A model that cannot be loaded yields an ``{"error": ...}`` dict instead of
    raising. Examples are an empty/placeholder name (ValueError) or an OSError
    raised while fetching the checkpoint. This way one broken model does not
    hide the other model's output.
    """
    try:
        pipe = get_pipeline(model_name, mode)
    except (ValueError, OSError) as err:
        return {"error": str(err), "model": model_name}
    raw = run_inference(pipe, cleaned, mode=mode)
    return {
        "cleaned_text": cleaned,
        "prediction": postprocess(raw, mode, threshold)[0],
    }


def compare_single(text, model_a, model_b, mode, threshold):
    """Classify one text with two models for side-by-side comparison.

    Args:
        text: Raw user input; ``None`` or blank text is rejected.
        model_a: Baseline model repo id.
        model_b: Candidate model repo id.
        mode: ``"binary"`` or ``"multilabel"``.
        threshold: Score cut-off used in multilabel mode.

    Returns:
        ``(result_a, result_b)``. Each is either
        ``{"cleaned_text", "prediction"}`` or an ``{"error": ...}`` dict.
    """
    if not text or not text.strip():
        return {"error": "Empty input"}, {"error": "Empty input"}
    cleaned = clean_text(text)
    return (
        _predict_single(cleaned, model_a, mode, threshold),
        _predict_single(cleaned, model_b, mode, threshold),
    )
# =====================================================
# BATCH CSV (COMPARE)
# =====================================================
def _add_prediction_columns(df, results, suffix, mode):
    """Append one model's post-processed predictions to ``df`` in place."""
    if mode == "binary":
        df[f"label_{suffix}"] = [r["label"] for r in results]
        df[f"conf_{suffix}"] = [r["score"] for r in results]
    else:
        df[f"labels_{suffix}"] = [[x["label"] for x in row] for row in results]
        df[f"scores_{suffix}"] = [[x["score"] for x in row] for row in results]


def batch_compare_csv(
    file,
    text_column,
    model_a,
    model_b,
    mode,
    threshold,
    batch_size
):
    """Classify every row of a CSV with two models and export the results.

    Args:
        file: Uploaded CSV, as a filepath string or an object exposing ``.name``.
        text_column: Name of the column holding the text to classify.
        model_a: Baseline model repo id.
        model_b: Candidate model repo id.
        mode: ``"binary"`` or ``"multilabel"``.
        threshold: Score cut-off used in multilabel mode.
        batch_size: Inference batch size (coerced to ``int``).

    Returns:
        ``(preview, path)``. ``preview`` holds the first 10 rows as records.
        ``path`` is a temporary CSV with the added prediction columns. On
        failure the result is ``({"error": ...}, None)``.
    """
    if file is None:
        return {"error": "CSV not uploaded"}, None
    # Newer Gradio passes a filepath string; older versions pass a
    # tempfile-like object whose ``.name`` is the path.
    path = getattr(file, "name", file)
    try:
        df = pd.read_csv(path)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        return {"error": f"Failed to read CSV: {err}"}, None
    if text_column not in df.columns:
        return {"error": f"Column '{text_column}' not found"}, None

    # fillna first: astype(str) alone would turn missing cells into "nan".
    texts = df[text_column].fillna("").astype(str).apply(clean_text).tolist()

    if batch_size is not None:
        # Guard against slider values delivered as floats (e.g. 16.0).
        batch_size = int(batch_size)

    try:
        pipe_a = get_pipeline(model_a, mode)
        pipe_b = get_pipeline(model_b, mode)
    except (ValueError, OSError) as err:
        return {"error": str(err)}, None

    res_a = postprocess(
        run_inference(pipe_a, texts, mode=mode, batch_size=batch_size), mode, threshold
    )
    res_b = postprocess(
        run_inference(pipe_b, texts, mode=mode, batch_size=batch_size), mode, threshold
    )
    _add_prediction_columns(df, res_a, "model_a", mode)
    _add_prediction_columns(df, res_b, "model_b", mode)

    # delete=False: the file must outlive this call so it can be downloaded.
    # The ``with`` closes the handle, so no file descriptor is leaked.
    with tempfile.NamedTemporaryFile(
        "w", delete=False, suffix=".csv", newline="", encoding="utf-8"
    ) as tmp:
        df.to_csv(tmp, index=False)
    return df.head(10).to_dict(orient="records"), tmp.name
# =====================================================
# UI
# =====================================================
with gr.Blocks(
    title="NLP Model Evaluation Platform",
    theme=gr.themes.Soft(primary_hue="blue", neutral_hue="slate"),
) as demo:
    gr.Markdown("# NLP Model Evaluation Platform")

    # --- Model discovery and shared settings ---------------------------
    hf_user = gr.Textbox(label="HuggingFace Username")
    load_btn = gr.Button("Load Models", variant="primary")
    status = gr.Markdown("")
    mode = gr.Radio(
        ["binary", "multilabel"],
        value="binary",
        label="Classification Mode",
    )
    threshold = gr.Slider(
        0.1,
        0.9,
        # Use the module-level default instead of duplicating the literal.
        value=DEFAULT_THRESHOLD,
        step=0.05,
        label="Multilabel Threshold",
    )

    # --- Single-text comparison ----------------------------------------
    gr.Markdown("## Single Text Comparison")
    text = gr.Textbox(lines=4, label="Input Text")
    with gr.Row():
        model_a = gr.Dropdown(label="Baseline Model")
        model_b = gr.Dropdown(label="Candidate Model")
    compare_btn = gr.Button("Compare Models", variant="primary")
    with gr.Row():
        out_a = gr.JSON(label="Baseline Output")
        out_b = gr.JSON(label="Candidate Output")

    gr.Markdown("---")

    # --- Batch CSV comparison ------------------------------------------
    gr.Markdown("## Batch CSV Comparison")
    csv_file = gr.File(file_types=[".csv"])
    text_col = gr.Textbox(label="Text Column Name")
    with gr.Row():
        batch_model_a = gr.Dropdown(label="Baseline Model")
        batch_model_b = gr.Dropdown(label="Candidate Model")
    batch_size = gr.Slider(1, 64, value=16, step=1, label="Batch Size")
    run_batch = gr.Button("Run Batch Compare", variant="primary")
    preview = gr.JSON(label="Preview (First 10 Rows)")
    download = gr.File(label="Download CSV")

    # --- Event wiring --------------------------------------------------
    load_btn.click(
        load_user_models,
        hf_user,
        [model_a, model_b, batch_model_a, batch_model_b, status],
    )
    compare_btn.click(
        compare_single,
        [text, model_a, model_b, mode, threshold],
        [out_a, out_b],
    )
    run_batch.click(
        batch_compare_csv,
        [csv_file, text_col, batch_model_a, batch_model_b, mode, threshold, batch_size],
        [preview, download],
    )
# =====================================================
# LAUNCH
# =====================================================
if __name__ == "__main__":
    # Listen on all interfaces; the port can be overridden through $PORT
    # and defaults to 7860.
    listen_port = int(os.getenv("PORT", "7860"))
    demo.launch(
        server_name="0.0.0.0",
        server_port=listen_port,
        ssr_mode=False,
    )