|
|
import gradio as gr |
|
|
import torch |
|
|
import joblib |
|
|
import numpy as np |
|
|
from transformers import BertTokenizer, BertModel |
|
|
|
|
|
|
|
|
|
|
|
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
|
|
|
|
|
|
|
|
print("Loading BERT model...")
# Tokenizer and encoder must come from the same pretrained checkpoint so the
# vocabulary matches the embedding table.
_BERT_CHECKPOINT = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(_BERT_CHECKPOINT)
bert_model = BertModel.from_pretrained(_BERT_CHECKPOINT)
# Place the weights on the selected device, then switch to evaluation mode
# (eval() makes layers such as dropout behave deterministically for inference).
bert_model.to(device)
bert_model.eval()
|
|
|
|
|
|
|
|
|
|
|
print("Loading classification components...")
# NOTE(review): joblib.load unpickles the file, so these artifacts must come
# from a trusted source (they are expected to be uploaded alongside the app).
try:
    # Trained classifier, the feature scaler applied to its inputs, and the
    # encoder that maps class indices back to label strings.
    mlp = joblib.load("mlp_query_classifier.joblib")
    scaler = joblib.load("scaler_query_classifier.joblib")
    le = joblib.load("label_encoder_query_classifier.joblib")
except FileNotFoundError as missing:
    # Best-effort: report the missing file and keep starting up. Any of
    # mlp / scaler / le not yet assigned remain undefined, so a later
    # prediction call that reads them would fail with NameError.
    print(f"Error: {missing}. Please make sure you uploaded the .joblib files to the Space.")
else:
    print("Loaded MLP, scaler, and label encoder.")
|
|
|
|
|
|
|
|
def get_bert_embeddings(text_list):
    """Encode a batch of texts with BERT and return their first-token vectors.

    Each text is tokenized with padding to the longest item in the batch and
    truncated to at most 128 tokens. The encoder runs without gradient
    tracking, and the hidden state at position 0 (the [CLS] token for a BERT
    tokenizer) is taken for every input.

    Args:
        text_list: A list of strings to embed.

    Returns:
        A NumPy array on the CPU with one row per input text.
    """
    batch = tokenizer(
        text_list,
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding=True,
    )
    # Put the input tensors on the same device as the model weights.
    batch = batch.to(device)

    # Inference only, so skip building the autograd graph.
    with torch.no_grad():
        hidden_states = bert_model(**batch).last_hidden_state

    # Keep the position-0 token from every sequence.
    first_token = hidden_states[:, 0, :]
    return first_token.cpu().numpy()
|
|
|
|
|
|
|
|
def predict_new_query(text):
    """Classify one query string with the BERT-embedding + MLP pipeline.

    The text is embedded with ``get_bert_embeddings``, transformed by the
    loaded ``scaler``, classified by ``mlp``, and the predicted class index
    is decoded back to a label with ``le``.

    Args:
        text: The raw query text to classify.

    Returns:
        ``"Label: <label> (Confidence: <p>)"``, where ``<p>`` is the highest
        class probability, when the classifier exposes ``predict_proba``.
        Otherwise the string is ``"Label: <label>"``.
    """
    # Wrap the text in a list so it forms a batch of one.
    embedding = get_bert_embeddings([text])
    # Assumed to be the same scaler that was fitted at training time (confirm).
    embedding_scaled = scaler.transform(embedding)

    prediction_index = mlp.predict(embedding_scaled)[0]
    # Map the encoded class index back to its human-readable label.
    label = le.inverse_transform([prediction_index])[0]

    # Some estimators do not provide probability estimates; scikit-learn
    # raises AttributeError when such a method is unavailable. Only that case
    # falls back to label-only output. Other errors are real bugs and should
    # surface instead of being silently swallowed by a bare ``except:``.
    try:
        predict_proba = mlp.predict_proba
    except AttributeError:
        return f"Label: {label}"

    probs = predict_proba(embedding_scaled)[0]
    confidence = np.max(probs)
    return f"Label: {label} (Confidence: {confidence:.2f})"
|
|
|
|
|
|
|
|
|
|
|
# Single multi-line text input; the prediction string is shown as plain text.
query_box = gr.Textbox(lines=2, placeholder="Enter your query here...")

iface = gr.Interface(
    fn=predict_new_query,
    inputs=query_box,
    outputs="text",
    title="BERT Query Classifier",
    description="Enter a text query to classify it using the custom BERT+MLP model.",
)

# Start the web server (blocks while the app is running).
iface.launch()