import traceback

import numpy as np
import requests
import streamlit as st
import torch
from huggingface_hub import hf_hub_download
from sklearn.preprocessing import LabelEncoder
from transformers import AutoTokenizer

from utils.model import BiLSTMAttentionBERT, BiLSTMConfig
def load_model_for_prediction():
    """Load the BiLSTM model, label encoder, and tokenizer for inference."""
    try:
        st.write("Starting model loading...")

        # Test Hugging Face connectivity before attempting any downloads
        st.write("Testing connection to Hugging Face...")
        response = requests.get("https://huggingface.co/joko333/BiLSTM_v01", timeout=10)
        if response.status_code != 200:
            st.error(f"Cannot connect to Hugging Face. Status code: {response.status_code}")
            return None, None, None

        # Build the model from its configuration and load the trained weights.
        # (The original file loaded the model twice, once from a state dict and
        # once via from_pretrained; a single load via the state dict suffices.)
        st.write("Loading BiLSTM model...")
        config = BiLSTMConfig(
            hidden_dim=128,
            num_classes=22,
            num_layers=2,
            dropout=0.5
        )
        model = BiLSTMAttentionBERT(config)
        model_path = hf_hub_download(
            repo_id="joko333/BiLSTM_v01",
            filename="model_epoch8_acc72.53.pt"
        )
        state_dict = torch.load(model_path, map_location='cpu')
        model.load_state_dict(state_dict)
        st.write("Model loaded successfully")

        # Initialize the label encoder with the 22 relation classes
        st.write("Initializing label encoder...")
        label_encoder = LabelEncoder()
        label_encoder.classes_ = np.array([
            'Addition', 'Causal', 'Cause and Effect',
            'Clarification', 'Comparison', 'Concession',
            'Conditional', 'Contrast', 'Contrastive Emphasis',
            'Definition', 'Elaboration', 'Emphasis',
            'Enumeration', 'Explanation', 'Generalization',
            'Illustration', 'Inference', 'Problem Solution',
            'Purpose', 'Sequential', 'Summary',
            'Temporal Sequence'
        ])
        st.write("Label encoder initialized")

        # Load the BioBERT tokenizer that matches the model's encoder
        st.write("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
        st.write("Tokenizer loaded successfully")

        return model, label_encoder, tokenizer

    except Exception as e:
        st.error(f"Detailed error: {str(e)}")
        st.error(f"Error type: {type(e).__name__}")
        st.error(f"Traceback: {traceback.format_exc()}")
        return None, None, None
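
# Usage sketch (an assumption, not part of the original app): Streamlit reruns
# the whole script on every interaction, so a cached wrapper avoids
# re-downloading the weights each time. `st.cache_resource` is Streamlit's
# standard cache for heavyweight objects; the wrapper name is hypothetical.
@st.cache_resource
def get_model_artifacts():
    """Hypothetical cached wrapper around load_model_for_prediction()."""
    return load_model_for_prediction()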
def predict_sentence(model, sentence, tokenizer, label_encoder):
    """
    Make a prediction for a single sentence with label validation.
    """
    # Validation checks
    if model is None:
        print("Error: Model not loaded")
        return "Error: Model not loaded", 0.0
    if tokenizer is None:
        print("Error: Tokenizer not loaded")
        return "Error: Tokenizer not loaded", 0.0
    if label_encoder is None:
        print("Error: Label encoder not loaded")
        return "Error: Label encoder not loaded", 0.0

    # Force CPU device
    device = torch.device('cpu')
    model = model.to(device)
    model.eval()

    try:
        # Tokenize the sentence into fixed-length input tensors
        encoding = tokenizer(
            sentence,
            add_special_tokens=True,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        ).to(device)

        # Run inference and convert logits to class probabilities
        with torch.no_grad():
            outputs = model(encoding['input_ids'], encoding['attention_mask'])
            probabilities = torch.softmax(outputs, dim=1)
            prob, pred_idx = torch.max(probabilities, dim=1)

        predicted_label = label_encoder.classes_[pred_idx.item()]
        return predicted_label, prob.item()

    except Exception as e:
        print(f"Prediction error: {str(e)}")
        return f"Error: {str(e)}", 0.0
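
# Sketch (an assumption, not in the original app): the same softmax output can
# be ranked with torch.topk to return the k most probable labels instead of a
# single argmax, which is often more informative in a UI. The helper name and
# default k are hypothetical.
def predict_topk(model, sentence, tokenizer, label_encoder, k=3):
    """Hypothetical variant of predict_sentence returning the top-k labels."""
    device = torch.device('cpu')
    model = model.to(device)
    model.eval()
    encoding = tokenizer(
        sentence,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    ).to(device)
    with torch.no_grad():
        outputs = model(encoding['input_ids'], encoding['attention_mask'])
        probabilities = torch.softmax(outputs, dim=1)
    probs, indices = torch.topk(probabilities, k, dim=1)
    return [(label_encoder.classes_[i.item()], p.item())
            for i, p in zip(indices[0], probs[0])]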
def print_labels(label_encoder, show_counts=False):
    """Print all labels and their corresponding indices."""
    print("\nAvailable labels:")
    print("-" * 40)
    for idx, label in enumerate(label_encoder.classes_):
        print(f"Index {idx}: {label}")
    print("-" * 40)
    print(f"Total number of classes: {len(label_encoder.classes_)}\n")