# hatamo's picture
# Initial deployment of Antique Authenticity API
# 718c4ae
# model.py
import torch
import torch.nn as nn
from transformers import DistilBertTokenizer, DistilBertModel
from torchvision.models import efficientnet_b0
class AuctionAuthenticityModel(nn.Module):
    """Image + text classifier that fuses two pretrained encoders.

    An EfficientNet-B0 backbone (classifier head replaced by ``nn.Identity``)
    encodes the image, and a multilingual DistilBERT encodes the text. The
    two feature vectors are concatenated and fed to a small MLP that emits
    ``num_classes`` logits.
    """

    def __init__(self, num_classes: int = 3, device='cpu'):  # 3 classes by default
        """Build both encoders, the tokenizer and the fusion head.

        Args:
            num_classes: Number of output logits.
            device: Fallback device for tokenized text. It is only used when
                the text encoder exposes no parameters; otherwise the
                encoder's actual device wins (see ``_text_device``).
        """
        super().__init__()
        self.device = device

        # Vision branch: drop the classification head to expose raw features.
        # NOTE(review): `pretrained=` is deprecated in torchvision >= 0.13 in
        # favour of `weights=`; kept as-is so older torchvision still works.
        self.vision_model = efficientnet_b0(pretrained=True)
        self.vision_model.classifier = nn.Identity()
        vision_out_dim = 1280  # feature width hard-coded for EfficientNet-B0

        # Text branch
        self.text_model = DistilBertModel.from_pretrained(
            'distilbert-base-multilingual-cased'
        )
        text_out_dim = 768  # hidden width hard-coded for DistilBERT base
        self.tokenizer = DistilBertTokenizer.from_pretrained(
            'distilbert-base-multilingual-cased'
        )

        # Fusion head (deliberately without BatchNorm)
        hidden_dim = 256
        self.fusion = nn.Sequential(
            nn.Linear(vision_out_dim + text_out_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        )

    def _text_device(self):
        """Return the device the text encoder currently lives on.

        ``self.device`` is captured at construction time and goes stale as
        soon as the module is moved with ``.to(...)``/``.cuda()``. Reading the
        device from the encoder's parameters keeps the tokens and the encoder
        on the same device; ``self.device`` is only the fallback when the
        encoder has no parameters.
        """
        first_param = next(self.text_model.parameters(), None)
        return self.device if first_param is None else first_param.device

    def forward(self, images, texts):
        """Compute class logits for a batch of image/text pairs.

        Args:
            images: Batched image tensor passed straight to the vision
                backbone (presumably ``(batch, 3, H, W)`` - TODO confirm).
            texts: Raw strings, tokenized here with padding and truncation
                to at most 512 tokens.

        Returns:
            Tensor of logits produced by the fusion head, one row per sample
            with ``num_classes`` columns.
        """
        vision_features = self.vision_model(images)

        tokens = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt',
        ).to(self._text_device())
        text_outputs = self.text_model(**tokens)
        # Use the hidden state at the first token position as the summary.
        text_features = text_outputs.last_hidden_state[:, 0, :]

        combined = torch.cat([vision_features, text_features], dim=1)
        return self.fusion(combined)

    def count_parameters(self) -> int:
        """Return the number of trainable (``requires_grad``) parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
if __name__ == '__main__':
    # Smoke test: build the model, run one forward pass on dummy data and
    # report the size of the serialized state dict.
    import os
    import tempfile

    print("Testowanie modelu...")
    device = torch.device('cpu')
    model = AuctionAuthenticityModel(device=device).to(device)
    # Inference-mode check: disable dropout and keep BatchNorm running
    # statistics from being updated by the random dummy batch.
    model.eval()
    print("✓ Model stworzony")
    print(f" - Parametrów: {model.count_parameters():,}")

    # Dummy test
    dummy_img = torch.randn(2, 3, 224, 224).to(device)
    dummy_texts = ["Silver spoon antique", "Polish silverware 19th century"]
    with torch.no_grad():
        output = model(dummy_img, dummy_texts)
    print(f"✓ Forward pass: {output.shape}")
    print(f" - Output: {output}")

    # Estimate model size
    print("\n📊 Rozmiar modelu:")
    # Serialize into a private temporary directory so an existing
    # 'temp_model.pt' in the working directory is never overwritten and the
    # file is removed even if measuring it fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        checkpoint_path = os.path.join(tmp_dir, 'temp_model.pt')
        torch.save(model.state_dict(), checkpoint_path)
        size_mb = os.path.getsize(checkpoint_path) / (1024 * 1024)
    print(f" - {size_mb:.1f} MB")