File size: 5,885 Bytes
ef70c53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
Complete working script to load ConceptFrameMet from HuggingFace with ALL weights.
This properly reconstructs the source_qa_model from checkpoint weights.
"""

from huggingface_hub import hf_hub_download
import torch
import torch.nn as nn
from transformers import RobertaModel, RobertaTokenizer, RobertaForSequenceClassification, RobertaConfig
import sys
import os

# Fetch the weight file and the label file from the HuggingFace Hub
print("Downloading from HuggingFace...")
weights_path = hf_hub_download(repo_id="nixie1981/ConceptFrameMet", filename="pytorch_model.bin")
labels_path = hf_hub_download(repo_id="nixie1981/ConceptFrameMet", filename="source_labels.json")

# Deserialize the checkpoint onto the CPU
# NOTE(review): torch.load unpickles the downloaded file; consider passing
# weights_only=True if the installed torch version supports it - confirm.
print("Loading checkpoint...")
state_dict = torch.load(weights_path, map_location='cpu')

print(f"Checkpoint has {len(state_dict)} keys")

# Is anything stored under the source_qa_model.* namespace?
has_source_qa = any(name.startswith('source_qa_model.') for name in state_dict)
print(f"Has source_qa_model weights: {has_source_qa}")

if has_source_qa:
    # Collect every key that belongs to the source_qa_model sub-module
    source_keys = [name for name in state_dict if name.startswith('source_qa_model.')]
    print(f"Source QA model has {len(source_keys)} keys")

    # Infer the architecture from the key names; the namespaces of interest are
    # source_qa_model.roberta.*, source_qa_model.frame_finder.* and
    # source_qa_model.source_classifier.*
    has_frame_finder = any('frame_finder' in name for name in source_keys)
    has_source_classifier = any('source_classifier' in name for name in source_keys)

    print(f"  - Has frame_finder: {has_frame_finder}")
    print(f"  - Has source_classifier: {has_source_classifier}")

    if has_frame_finder and has_source_classifier:
        print("\nThis is a TrueMultiTaskModel (frame + source)!")
        print("Creating source_qa_model structure...")

        # The output dimension of each head is the first axis of its weight
        frame_weight_key = 'source_qa_model.frame_finder.classifier.out_proj.weight'
        source_weight_key = 'source_qa_model.source_classifier.weight'

        if frame_weight_key in state_dict:
            num_frames = state_dict[frame_weight_key].shape[0]
        else:
            num_frames = None

        if source_weight_key in state_dict:
            num_sources = state_dict[source_weight_key].shape[0]
        else:
            num_sources = None

        print(f"  - num_frames: {num_frames}")
        print(f"  - num_sources: {num_sources}")

        if num_frames and num_sources:
            # Both head sizes are known, so the model structure can be built
            config = RobertaConfig.from_pretrained('roberta-base')

            # The source classifier's input width is the second axis of its weight
            source_classifier_weight = state_dict.get('source_qa_model.source_classifier.weight')
            if source_classifier_weight is None:
                source_classifier_input_size = None
            else:
                source_classifier_input_size = source_classifier_weight.shape[1]

            print(f"  - source_classifier input size: {source_classifier_input_size}")
            class TrueMultiTaskModel(nn.Module):
                """Frame + source classifier matching the ``source_qa_model.*`` checkpoint layout.

                ``frame_finder`` is a ``RobertaForSequenceClassification`` that scores
                frames. ``roberta`` is a separate encoder whose pooled output is
                concatenated with the frame logits and passed through
                ``source_classifier``.

                Args:
                    config: RoBERTa configuration used for both sub-models. It is not
                        modified; the frame model gets its own copy.
                    num_frames: Number of frame classes.
                    num_sources: Number of source classes.
                    source_input_size: Input width of ``source_classifier``. When
                        ``None``, ``config.hidden_size + num_frames`` is used, which is
                        the width of the tensor that ``forward`` concatenates.
                """

                class _Output:
                    """Minimal result holder exposing a ``logits`` attribute."""

                    def __init__(self, logits):
                        self.logits = logits

                def __init__(self, config, num_frames, num_sources, source_input_size):
                    import copy

                    super().__init__()
                    self.config = config
                    self.num_frames = num_frames
                    self.num_sources = num_sources

                    self.roberta = RobertaModel(config)

                    # Keep the stock classification head (classifier.dense /
                    # classifier.out_proj): the loader script reads the frame head from
                    # 'frame_finder.classifier.out_proj.weight', so a bare nn.Linear in
                    # this slot would never receive those weights. The head is sized
                    # through num_labels on a private copy of the config so the caller's
                    # config object stays untouched.
                    frame_config = copy.deepcopy(config)
                    frame_config.num_labels = num_frames
                    self.frame_finder = RobertaForSequenceClassification(frame_config)

                    # Source classifier - prefer the size read from the checkpoint,
                    # otherwise fall back to what forward() actually concatenates.
                    if source_input_size is None:
                        source_input_size = config.hidden_size + num_frames
                    self.dropout = nn.Dropout(config.hidden_dropout_prob)
                    self.source_classifier = nn.Linear(source_input_size, num_sources)

                def forward(self, input_ids=None, attention_mask=None,
                            frame_input_ids=None, frame_attention_mask=None, **kwargs):
                    """Run frame prediction and, when ``input_ids`` is given, source prediction.

                    Args:
                        input_ids: Token ids for the source encoder. When ``None`` only
                            the frame head is evaluated.
                        attention_mask: Attention mask paired with ``input_ids``.
                        frame_input_ids: Token ids for the frame model (required).
                        frame_attention_mask: Attention mask paired with
                            ``frame_input_ids``.
                        **kwargs: Accepted and ignored.

                    Returns:
                        An object whose ``logits`` attribute holds the source logits
                        when ``input_ids`` is provided, otherwise the frame logits.

                    Raises:
                        ValueError: If ``frame_input_ids`` is ``None``.
                    """
                    if frame_input_ids is None:
                        # The frame logits feed both return paths, so they are mandatory
                        raise ValueError("frame_input_ids is required to compute frame logits")

                    # Frame prediction
                    frame_outputs = self.frame_finder(input_ids=frame_input_ids,
                                                      attention_mask=frame_attention_mask)
                    frame_logits = frame_outputs.logits

                    if input_ids is None:
                        return self._Output(frame_logits)

                    # Source prediction: pooled encoder output joined with the frame logits
                    source_outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
                    pooled_output = source_outputs.pooler_output
                    combined = torch.cat([pooled_output, frame_logits], dim=1)
                    combined = self.dropout(combined)
                    return self._Output(self.source_classifier(combined))
            
            # Instantiate the architecture, then fill it from the checkpoint
            source_qa_model = TrueMultiTaskModel(config, num_frames, num_sources, source_classifier_input_size)

            # Keep only the source_qa_model.* entries and strip that leading prefix.
            # Slicing removes the prefix alone; str.replace would also delete any
            # later occurrence of the same substring inside a key.
            source_prefix = 'source_qa_model.'
            source_state_dict = {
                key[len(source_prefix):]: tensor
                for key, tensor in state_dict.items()
                if key.startswith(source_prefix)
            }

            # strict=False tolerates key differences, so name every difference:
            # a missing key means that parameter keeps its random initialization.
            missing, unexpected = source_qa_model.load_state_dict(source_state_dict, strict=False)
            print(f"\nLoaded source_qa_model: missing={len(missing)}, unexpected={len(unexpected)}")
            for key in missing:
                print(f"  - missing (left at random init): {key}")
            for key in unexpected:
                print(f"  - unexpected (ignored): {key}")

            # Switch to evaluation mode so dropout is disabled at inference time
            source_qa_model.eval()

            if missing:
                print("\n⚠️ SOURCE_QA_MODEL CREATED, BUT SOME WEIGHTS WERE NOT LOADED!")
            else:
                print("\n✅ SOURCE_QA_MODEL CREATED AND LOADED!")
                print("Now the full model will work correctly!")