File size: 2,602 Bytes
718c4ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# model.py
import torch
import torch.nn as nn
from transformers import DistilBertTokenizer, DistilBertModel
from torchvision.models import efficientnet_b0

class AuctionAuthenticityModel(nn.Module):
    """Image + text classifier fusing EfficientNet-B0 and DistilBERT features.

    The torchvision EfficientNet-B0 backbone has its classifier replaced by
    ``nn.Identity`` so it emits 1280 features per image. The multilingual
    DistilBERT encoder contributes the 768-dimensional hidden state of the
    first token of each text. Both vectors are concatenated and passed
    through a small MLP head that produces ``num_classes`` logits.
    """

    def __init__(self, num_classes=3, device='cpu'):  # 3 classes!
        """Build both backbones, the tokenizer and the fusion head.

        Args:
            num_classes: Size of the final linear layer (number of logits).
            device: Stored on ``self.device`` for callers that read it. The
                constructor does not move any weights; ``forward`` places
                the tokenised text on the device the parameters live on.
        """
        super().__init__()
        self.device = device

        # Vision
        # NOTE(review): newer torchvision releases prefer the ``weights=``
        # argument over ``pretrained=True`` -- confirm against the pinned
        # torchvision version before changing it.
        self.vision_model = efficientnet_b0(pretrained=True)
        # Drop the ImageNet classifier so the backbone returns raw features.
        self.vision_model.classifier = nn.Identity()
        vision_out_dim = 1280

        # Text
        self.text_model = DistilBertModel.from_pretrained(
            'distilbert-base-multilingual-cased'
        )
        text_out_dim = 768

        self.tokenizer = DistilBertTokenizer.from_pretrained(
            'distilbert-base-multilingual-cased'
        )

        # Fusion (no BatchNorm!)
        hidden_dim = 256
        self.fusion = nn.Sequential(
            nn.Linear(vision_out_dim + text_out_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes)
        )

    def _runtime_device(self):
        """Return the device the module's parameters currently live on."""
        first_param = next(self.parameters(), None)
        # Only a parameter-less module falls back to the constructor value.
        return self.device if first_param is None else first_param.device

    def forward(self, images, texts):
        """Compute class logits for a batch of (image, text) pairs.

        Args:
            images: Image batch fed to the vision backbone; presumably a
                float tensor shaped (batch, 3, H, W) -- confirm with the
                data pipeline.
            texts: A string or a sequence of strings handed to the
                tokenizer; inputs are padded and truncated to 512 tokens.

        Returns:
            The output of the fusion head: one row of ``num_classes``
            logits per sample.
        """
        vision_features = self.vision_model(images)

        # Resolve the device at call time rather than trusting the value
        # captured in ``__init__``: after ``model.to(...)`` / ``.cuda()``
        # the stored ``self.device`` may no longer match the weights.
        tokens = self.tokenizer(
            texts, padding=True, truncation=True, max_length=512, return_tensors='pt'
        ).to(self._runtime_device())
        text_outputs = self.text_model(**tokens)
        # Hidden state of the first token serves as the text embedding.
        text_features = text_outputs.last_hidden_state[:, 0, :]

        combined = torch.cat([vision_features, text_features], dim=1)
        logits = self.fusion(combined)
        return logits

    def count_parameters(self):
        """Return the number of trainable (``requires_grad``) parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


if __name__ == '__main__':
    # Smoke test: build the model on CPU, run one forward pass on dummy
    # inputs and report the size of the serialized state dict.
    import os
    import tempfile

    print("Testowanie modelu...")

    device = torch.device('cpu')
    model = AuctionAuthenticityModel(device=device).to(device)

    print("✓ Model stworzony")
    print(f"  - Parametrów: {model.count_parameters():,}")

    # Dummy test
    dummy_img = torch.randn(2, 3, 224, 224).to(device)
    dummy_texts = ["Silver spoon antique", "Polish silverware 19th century"]

    with torch.no_grad():
        output = model(dummy_img, dummy_texts)

    print(f"✓ Forward pass: {output.shape}")
    print(f"  - Output: {output}")

    # Estimate model size
    print("\n📊 Rozmiar modelu:")
    # Serialize into a private temporary directory so the check can never
    # overwrite or delete a real 'temp_model.pt' in the working directory,
    # and so the file is removed even if saving or measuring fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, 'temp_model.pt')
        torch.save(model.state_dict(), tmp_path)
        size_mb = os.path.getsize(tmp_path) / (1024*1024)
    print(f"  - {size_mb:.1f} MB")