import torch
import torch.nn as nn
import os
import json
from huggingface_hub import hf_hub_download
from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedModel

class HarmFormer(PreTrainedModel):
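    """Multi-task harm classifier: a shared transformer encoder feeding one
    classification head per harm category, each head predicting logits over
    the configured number of risk levels."""
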
    def __init__(self, config):
        super().__init__(config)
        self.num_classes = config.num_classes
        self.num_risk_levels = config.num_risk_levels
        
        # Base model
        self.base_model = AutoModel.from_config(config)

        # Classification heads
        hidden_size = self.base_model.config.hidden_size
        
        self.classifiers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, 128),
                nn.ReLU(),
                nn.Linear(128, self.num_risk_levels)
            )
            for _ in range(self.num_classes)
        ])
    
    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output  # Pooled first-token ([CLS]) representation
        
        # Apply classifiers for each task
        logits = []
        for classifier in self.classifiers:
            logits.append(classifier(pooled_output))
        
        return logits
    
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
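        """Loads the custom config (config.json) and weights (pytorch_model.bin)
        from a local directory, falling back to downloading from the HF Hub."""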
        # Load config
        config_path = os.path.join(pretrained_model_name_or_path, "config.json")
        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                model_config = json.load(f)
        else:
            # Try to load from HF Hub
            config_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename="config.json")
            with open(config_path, 'r') as f:
                model_config = json.load(f)
        
        # Create base model config
        base_model_name = model_config.get("model_name", "allenai/longformer-base-4096")
        base_config = AutoConfig.from_pretrained(base_model_name)
        
        # Add our custom attributes
        base_config.num_classes = model_config.get("num_classes", 5)
        base_config.num_risk_levels = model_config.get("num_risk_levels", 3)
        base_config.architecture = model_config.get("architecture", "SingleFC")
        
        # Create model
        model = cls(base_config)
        
        # Load weights
        checkpoint_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
        if os.path.exists(checkpoint_path):
            state_dict = torch.load(checkpoint_path, map_location="cpu")
        else:
            # Try to load from HF Hub
            checkpoint_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename="pytorch_model.bin")
            state_dict = torch.load(checkpoint_path, map_location="cpu")
        
        model.load_state_dict(state_dict)
        model.eval()
        
        return model

def predict_batch(model, tokenizer, texts, batch_size=32):
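    """Tokenizes `texts` in batches and returns, for each text, a
    num_classes x num_risk_levels grid of rounded softmax probabilities."""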
    device = next(model.parameters()).device
    predictions = []
    
    # Process in batches to avoid OOM
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i+batch_size]
        inputs = tokenizer(
            batch_texts,
            add_special_tokens=True,
            max_length=1024,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        ).to(device)
        
        with torch.no_grad():
            outputs = model(**inputs)
            logits = torch.stack(outputs, dim=0).permute(1, 0, 2)  # (batch_size, num_classes, num_risk_levels)
            probs = torch.softmax(logits, dim=-1)
            # Round for readability: each sample becomes a num_classes x num_risk_levels grid
            batch_preds = [
                [[round(prob, 3) for prob in class_probs] for class_probs in sample]
                for sample in probs.cpu().tolist()
            ]
            predictions.extend(batch_preds)
    
    return predictions
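
# Example usage: a minimal sketch of loading the model and scoring a few texts.
# The checkpoint path below is a placeholder; point it at the actual local
# directory or Hub repo hosting the HarmFormer weights and config.json.
if __name__ == "__main__":
    model_path = "path/to/harmformer-checkpoint"  # hypothetical path; replace as needed
    model = HarmFormer.from_pretrained(model_path)

    # The tokenizer is assumed to match the base encoder named in config.json
    # (default: allenai/longformer-base-4096).
    tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")

    texts = [
        "This is a harmless example sentence.",
        "Another sample document to classify.",
    ]
    preds = predict_batch(model, tokenizer, texts, batch_size=2)
    for text, pred in zip(texts, preds):
        # pred is a num_classes x num_risk_levels grid of probabilities
        print(text[:40], pred)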