# model.py
import os
import tempfile

import torch
import torch.nn as nn
from transformers import DistilBertTokenizer, DistilBertModel
from torchvision.models import efficientnet_b0


class AuctionAuthenticityModel(nn.Module):
    """Image + text classifier that fuses EfficientNet-B0 and DistilBERT features.

    The image branch is an EfficientNet-B0 whose classifier head is replaced
    by ``nn.Identity``; the text branch is the
    ``distilbert-base-multilingual-cased`` encoder. Both feature vectors are
    concatenated and passed through a small MLP that emits ``num_classes``
    logits.

    Args:
        num_classes: Number of output logits (3 classes by default).
        device: Stored as ``self.device`` for callers that read it. The
            forward pass does not rely on it; it follows the device the
            weights actually live on.
    """

    def __init__(self, num_classes=3, device='cpu'):
        super().__init__()
        # Kept for backward compatibility only; see forward() for why it is
        # not used to place the token tensors.
        self.device = device

        # Vision
        # NOTE(review): `pretrained=True` is deprecated in newer torchvision
        # releases in favour of `weights=...`; left unchanged so older
        # installations keep working.
        self.vision_model = efficientnet_b0(pretrained=True)
        self.vision_model.classifier = nn.Identity()
        vision_out_dim = 1280

        # Text
        self.text_model = DistilBertModel.from_pretrained(
            'distilbert-base-multilingual-cased'
        )
        text_out_dim = 768

        self.tokenizer = DistilBertTokenizer.from_pretrained(
            'distilbert-base-multilingual-cased'
        )

        # Fusion (no BatchNorm!)
        hidden_dim = 256
        self.fusion = nn.Sequential(
            nn.Linear(vision_out_dim + text_out_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        )

    def forward(self, images, texts):
        """Compute class logits for a batch of (image, text) pairs.

        Args:
            images: Image batch fed directly to the EfficientNet backbone.
            texts: Raw text(s) passed to the tokenizer (padded, truncated to
                at most 512 tokens).

        Returns:
            Tensor of logits produced by the fusion head, one row per sample.
        """
        vision_features = self.vision_model(images)

        # Place the tokens on the device the text encoder's weights are on.
        # Using the constructor-time `self.device` would break as soon as the
        # module is moved afterwards (e.g. `model.to('cuda')`).
        text_device = next(self.text_model.parameters()).device
        tokens = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt',
        ).to(text_device)

        text_outputs = self.text_model(**tokens)
        # Use the hidden state at the first token position as the text summary.
        text_features = text_outputs.last_hidden_state[:, 0, :]

        combined = torch.cat([vision_features, text_features], dim=1)
        logits = self.fusion(combined)
        return logits

    def count_parameters(self):
        """Return the total number of trainable (``requires_grad``) parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


def _run_smoke_test():
    """Build the model on CPU, run one dummy forward pass, report its size."""
    print("Testowanie modelu...")

    device = torch.device('cpu')
    model = AuctionAuthenticityModel(device=device).to(device)
    # Inference-style check: disable Dropout so the printed output is not
    # perturbed by training-mode noise.
    model.eval()

    print("✓ Model stworzony")
    print(f" - Parametrów: {model.count_parameters():,}")

    # Dummy test
    dummy_img = torch.randn(2, 3, 224, 224).to(device)
    dummy_texts = ["Silver spoon antique", "Polish silverware 19th century"]

    with torch.no_grad():
        output = model(dummy_img, dummy_texts)

    print(f"✓ Forward pass: {output.shape}")
    print(f" - Output: {output}")

    # Estimate model size. The checkpoint goes into a throw-away directory so
    # nothing is left behind (or overwritten) in the working directory, even
    # if saving fails half-way.
    print("\n📊 Rozmiar modelu:")
    with tempfile.TemporaryDirectory() as tmp_dir:
        checkpoint_path = os.path.join(tmp_dir, 'temp_model.pt')
        torch.save(model.state_dict(), checkpoint_path)
        size_mb = os.path.getsize(checkpoint_path) / (1024 * 1024)
    print(f" - {size_mb:.1f} MB")


if __name__ == '__main__':
    _run_smoke_test()