File size: 4,125 Bytes
1b17d16
ca5c473
67bf242
ca5c473
 
42d8a45
 
1b17d16
ca5c473
 
 
 
 
42d8a45
67bf242
 
 
 
 
1b17d16
 
 
 
 
 
 
 
67bf242
3681591
67bf242
1b17d16
 
 
 
3681591
42d8a45
67bf242
 
 
3681591
67bf242
 
3681591
67bf242
 
 
 
 
 
3681591
 
 
42d8a45
 
3681591
ca5c473
3681591
 
 
 
 
 
42d8a45
 
ca5c473
 
 
 
3681591
ca5c473
 
621f6b2
ca5c473
 
 
04dc908
 
 
 
 
 
 
 
 
 
 
 
 
 
ca5c473
 
 
 
04dc908
 
 
 
 
 
 
 
 
ca5c473
 
 
 
04dc908
 
ca5c473
 
 
04dc908
ca5c473
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from utils.model import BiLSTMAttentionBERT, BiLSTMConfig
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.preprocessing import LabelEncoder
import numpy as np
import streamlit as st
import requests
from huggingface_hub import hf_hub_download



def load_model_for_prediction():
    """Build the classifier and load its label encoder and tokenizer.

    The BioBERT encoder is taken from ``dmis-lab/biobert-base-cased-v1.2``;
    only the non-``bert.`` weights are restored from the checkpoint downloaded
    from the ``joko333/BiLSTM_v01`` hub repository. Progress and errors are
    reported through Streamlit.

    Returns:
        tuple: ``(model, label_encoder, tokenizer)`` on success, or
        ``(None, None, None)`` when the checkpoint lacks the expected entries
        or any step raises.
    """
    try:
        st.write("Starting model loading...")

        # Initialize BERT first
        bert = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.2')

        # Initialize config and model
        config = BiLSTMConfig(
            hidden_dim=128,
            num_classes=22,
            num_layers=2,
            dropout=0.5
        )

        model = BiLSTMAttentionBERT(config)
        model.bert = bert  # Set pre-trained BERT

        # Load custom layers from checkpoint
        model_path = hf_hub_download(
            repo_id="joko333/BiLSTM_v01",
            filename="model_epoch8_acc72.53.pt"
        )
        # NOTE(review): torch.load deserializes arbitrary objects from the
        # file; only point this at a checkpoint repository you trust.
        checkpoint = torch.load(model_path, map_location='cpu')

        # Debug checkpoint structure
        st.write("Checkpoint keys:", checkpoint.keys())

        if 'model_state_dict' not in checkpoint:
            st.error("Invalid checkpoint format")
            return None, None, None

        # Keep only the custom layer weights; the BERT weights come from the
        # pre-trained encoder assigned above.
        custom_state_dict = {
            key: value
            for key, value in checkpoint['model_state_dict'].items()
            if not key.startswith('bert.')
        }

        # strict=False is needed because the bert.* entries were dropped, but
        # it also hides genuinely missing custom weights, so report those.
        load_result = model.load_state_dict(custom_state_dict, strict=False)
        missing_custom = [
            key
            for key in getattr(load_result, 'missing_keys', ())
            if not key.startswith('bert.')
        ]
        unexpected = list(getattr(load_result, 'unexpected_keys', ()))
        if missing_custom:
            st.warning(f"Checkpoint is missing weights for: {missing_custom}")
        if unexpected:
            st.warning(f"Checkpoint has unexpected weights: {unexpected}")
        st.write("Model loaded successfully")

        # Initialize label encoder from checkpoint
        if 'label_encoder_classes' not in checkpoint:
            st.error("Label encoder data not found in checkpoint")
            return None, None, None
        label_encoder = LabelEncoder()
        label_encoder.classes_ = checkpoint['label_encoder_classes']

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.2')

        return model, label_encoder, tokenizer

    except Exception as e:
        # Top-level boundary for the UI: surface the failure instead of crashing.
        st.error(f"Error loading model: {str(e)}")
        return None, None, None

def predict_sentence(model, sentence, tokenizer, label_encoder):
    """Classify one sentence and return its label with the top probability.

    Args:
        model: Callable module invoked as ``model(input_ids, attention_mask)``.
        sentence: Text passed to the tokenizer.
        tokenizer: Callable producing ``input_ids`` and ``attention_mask``.
        label_encoder: Object whose ``classes_`` maps class index to label.

    Returns:
        tuple: ``(predicted_label, probability)``. If a required component is
        ``None`` or prediction raises, the first element is an ``"Error: ..."``
        string and the probability is ``0.0``.
    """
    # Each required component is checked in order; the first missing one wins.
    required = (
        (model, "Error: Model not loaded"),
        (tokenizer, "Error: Tokenizer not loaded"),
        (label_encoder, "Error: Label encoder not loaded"),
    )
    for component, message in required:
        if component is None:
            print(message)
            return message, 0.0

    # Inference is always run on the CPU, in evaluation mode.
    cpu = torch.device('cpu')
    model = model.to(cpu)
    model.eval()

    try:
        # Pad/truncate to a fixed 512-token window and return PyTorch tensors.
        batch = tokenizer(
            sentence,
            add_special_tokens=True,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        ).to(cpu)

        with torch.no_grad():
            logits = model(batch['input_ids'], batch['attention_mask'])
            class_probs = torch.softmax(logits, dim=1)
            top_prob, top_idx = class_probs.max(dim=1)

        return label_encoder.classes_[top_idx.item()], top_prob.item()

    except Exception as e:
        print(f"Prediction error: {str(e)}")
        return f"Error: {str(e)}", 0.0
    
def print_labels(label_encoder, show_counts=False):
    """Print every class label with its index, followed by the class count.

    Args:
        label_encoder: Object exposing the labels through ``classes_``.
        show_counts: Accepted for interface compatibility; not used here.
    """
    classes = label_encoder.classes_
    divider = "-" * 40

    print("\nAvailable labels:")
    print(divider)
    for position, name in enumerate(classes):
        print(f"Index {position}: {name}")
    print(divider)
    print(f"Total number of classes: {len(classes)}\n")