import gradio as gr
import torch
import joblib
import numpy as np
from transformers import BertTokenizer, BertModel

# ----------------- 1. Setup Device -----------------
# HF Spaces (Free) usually runs on CPU, but this keeps it robust
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# ----------------- 2. Load BERT -----------------
print("Loading BERT model...")
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')
bert_model.to(device)
bert_model.eval()  # inference only; disables dropout

# ----------------- 3. Load MLP + Scaler + LabelEncoder -----------------
# Ensure these files are uploaded to your HF Space Files tab!
print("Loading classification components...")
try:
    mlp = joblib.load("mlp_query_classifier.joblib")
    scaler = joblib.load("scaler_query_classifier.joblib")
    le = joblib.load("label_encoder_query_classifier.joblib")
    print("Loaded MLP, scaler, and label encoder.")
except FileNotFoundError as e:
    print(f"Error: {e}. Please make sure you uploaded the .joblib files to the Space.")
    # Fail fast: swallowing this error would leave `mlp`/`scaler`/`le`
    # undefined and crash later with a confusing NameError on the first
    # prediction. Re-raise so the Space shows the real cause at startup.
    raise

# ----------------- 4. Embedding Function -----------------
def get_bert_embeddings(text_list):
    """Return [CLS] embeddings for a list of strings.

    Args:
        text_list: list of raw query strings.

    Returns:
        numpy array of shape (len(text_list), hidden_size) with the
        last-hidden-state [CLS] vector for each input.
    """
    inputs = tokenizer(
        text_list,
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():  # inference only — no autograd graph needed
        outputs = bert_model(**inputs)
    # Position 0 of the sequence is the [CLS] token.
    cls_embeddings = outputs.last_hidden_state[:, 0, :]
    return cls_embeddings.cpu().numpy()
# ----------------- 5. Prediction Function -----------------
def predict_new_query(text):
    """Classify a single query string with the BERT+MLP pipeline.

    Args:
        text: raw query string entered by the user.

    Returns:
        A human-readable result string, including the model's confidence
        when the classifier exposes `predict_proba`.
    """
    # 1) BERT embedding for the single query (batch of one)
    embedding = get_bert_embeddings([text])
    # 2) Scale with the same scaler fitted at training time
    embedding_scaled = scaler.transform(embedding)
    # 3) MLP prediction -> class index
    prediction_index = mlp.predict(embedding_scaled)[0]
    # 4) Map index back to the original string label
    label = le.inverse_transform([prediction_index])[0]
    # Report confidence when available. A hasattr check replaces the
    # original bare `except:`, which would also have swallowed
    # KeyboardInterrupt/SystemExit and genuine bugs in predict_proba.
    if hasattr(mlp, "predict_proba"):
        probs = mlp.predict_proba(embedding_scaled)[0]
        confidence = np.max(probs)
        return f"Label: {label} (Confidence: {confidence:.2f})"
    return f"Label: {label}"

# ----------------- 6. Launch Gradio Interface -----------------
# This creates the web UI
iface = gr.Interface(
    fn=predict_new_query,
    inputs=gr.Textbox(lines=2, placeholder="Enter your query here..."),
    outputs="text",
    title="BERT Query Classifier",
    description="Enter a text query to classify it using the custom BERT+MLP model.",
)

# Guard the launch so importing this module (e.g. for tests) does not
# start a server. HF Spaces executes the script directly, so the app
# still launches there.
if __name__ == "__main__":
    iface.launch()