joko333's picture
Refactor model loading function to handle checkpoints and improve error handling
3681591
raw
history blame
3.56 kB
from utils.model import BiLSTMAttentionBERT, BiLSTMConfig
import torch
from transformers import AutoTokenizer
from sklearn.preprocessing import LabelEncoder
import numpy as np
import streamlit as st
import requests
from huggingface_hub import hf_hub_download
def load_model_for_prediction(
    *,
    repo_id="joko333/BiLSTM_v01",
    filename="model_epoch8_acc72.53.pt",
    tokenizer_name='dmis-lab/biobert-base-cased-v1.2',
):
    """Download the BiLSTM checkpoint and build everything needed for inference.

    Builds a ``BiLSTMAttentionBERT`` from a fixed ``BiLSTMConfig``, loads the
    weights stored under ``'model_state_dict'`` in the checkpoint, restores the
    label classes stored under ``'label_encoder_classes'``, and loads the
    tokenizer. Progress and failures are reported through Streamlit
    (``st.write`` / ``st.error``).

    Args:
        repo_id: Hugging Face Hub repository that holds the checkpoint.
        filename: Checkpoint file name inside ``repo_id``.
        tokenizer_name: Name or path passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        ``(model, label_encoder, tokenizer)`` on success, with the model in
        eval mode; ``(None, None, None)`` if any step fails. Errors are shown
        via ``st.error`` and never propagate to the caller.
    """
    try:
        st.write("Starting model loading...")
        # These hyperparameters must match the architecture the checkpoint was
        # trained with; load_state_dict (strict by default) raises on
        # mismatched keys or tensor shapes.
        config = BiLSTMConfig(
            hidden_dim=128,
            num_classes=22,
            num_layers=2,
            dropout=0.5
        )
        model = BiLSTMAttentionBERT(config)

        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
        # NOTE(review): torch.load unpickles the file, so only load checkpoints
        # from trusted repositories. Newer PyTorch releases default to
        # weights_only=True, which may reject non-tensor entries such as the
        # stored label classes -- confirm against the installed torch version.
        checkpoint = torch.load(model_path, map_location='cpu')

        # Validate the whole checkpoint before touching the model, so a
        # success message is never shown for a checkpoint rejected later.
        if not isinstance(checkpoint, dict) or 'model_state_dict' not in checkpoint:
            st.error("Invalid checkpoint format")
            return None, None, None
        if 'label_encoder_classes' not in checkpoint:
            st.error("Label encoder data not found in checkpoint")
            return None, None, None

        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        st.write("Model loaded successfully")

        label_encoder = LabelEncoder()
        # LabelEncoder.inverse_transform indexes classes_ with an integer
        # array, which only works on an ndarray, not a plain list.
        label_encoder.classes_ = np.asarray(checkpoint['label_encoder_classes'])

        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        return model, label_encoder, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None, None, None
def predict_sentence(model, sentence, tokenizer, label_encoder):
    """
    Classify a single sentence and return ``(label, confidence)``.

    Guard rules:
        If ``model``, ``tokenizer`` or ``label_encoder`` is ``None`` (checked
        in that order), an error message is printed to stdout. That message is
        returned as the label, with a confidence of ``0.0``.

    Inference:
        The model is moved to CPU and switched to eval mode. The sentence is
        tokenized to a fixed length of 512 tokens (padded and truncated). The
        label is the class with the highest softmax probability, and that
        probability is returned as the confidence.

    Any exception raised during tokenization or inference is printed and
    returned as ``("Error: <message>", 0.0)`` instead of propagating.
    """
    # Each required component paired with the message reported when it is missing.
    required = (
        (model, "Error: Model not loaded"),
        (tokenizer, "Error: Tokenizer not loaded"),
        (label_encoder, "Error: Label encoder not loaded"),
    )
    for component, message in required:
        if component is None:
            print(message)
            return message, 0.0

    # Inference always runs on CPU.
    cpu = torch.device('cpu')
    model = model.to(cpu)
    model.eval()

    try:
        batch = tokenizer(
            sentence,
            add_special_tokens=True,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        batch = batch.to(cpu)
        with torch.no_grad():
            logits = model(batch['input_ids'], batch['attention_mask'])
            confidence, index = torch.softmax(logits, dim=1).max(dim=1)
        label = label_encoder.classes_[index.item()]
        return label, confidence.item()
    except Exception as exc:
        print(f"Prediction error: {str(exc)}")
        return f"Error: {str(exc)}", 0.0
def print_labels(label_encoder, show_counts=False):
    """Print every class label with its integer index, framed by divider lines.

    ``label_encoder`` only needs a ``classes_`` sequence. ``show_counts`` is
    accepted for API compatibility but is currently ignored.
    """
    divider = "-" * 40
    print("\nAvailable labels:", divider, sep="\n")
    for position, name in enumerate(label_encoder.classes_):
        print(f"Index {position}: {name}")
    print(divider, f"Total number of classes: {len(label_encoder.classes_)}\n", sep="\n")