hatamo's picture
Initial deployment of Antique Authenticity API
718c4ae
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from model import AuctionAuthenticityModel
from dataset_loader import AuctionDatasetFromJSON, get_transforms
import json
def train_epoch(model, loader, optimizer, device, epoch):
    """Run one training pass over ``loader`` and return the mean loss.

    Each batch is read as a mapping with ``'image'``, ``'text'`` and
    ``'label'`` entries. Images and labels are moved to ``device``; texts
    are handed to the model unchanged.

    Args:
        model: Module called as ``model(images, texts)`` to produce logits.
        loader: Iterable of batches as described above.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the image and label tensors are moved to.
        epoch: Epoch number, used only in the progress-bar caption.

    Returns:
        Cross-entropy loss averaged over every sample seen in the epoch.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    sample_count = 0
    progress_bar = tqdm(loader, desc=f"Epoch {epoch} [TRAIN]")
    for batch in progress_bar:
        images = batch['image'].to(device)
        texts = batch['text']
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        logits = model(images, texts)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        # Cap the global gradient norm at 1.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Read the scalar once; it is reused for the total and the caption.
        batch_loss = loss.item()
        batch_size = labels.size(0)
        # cross_entropy returns the batch mean, so weight it by the batch
        # size; otherwise a short final batch counts as much as a full one.
        loss_sum += batch_loss * batch_size
        sample_count += batch_size
        progress_bar.set_postfix(loss=f'{batch_loss:.4f}')

    if sample_count == 0:
        raise ValueError("train_epoch: loader yielded no samples")
    return loss_sum / sample_count
def validate(model, loader, device, epoch):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Each batch is read as a mapping with ``'image'``, ``'text'`` and
    ``'label'`` entries. Predictions are the argmax of the logits along
    dimension 1.

    Args:
        model: Module called as ``model(images, texts)`` to produce logits.
        loader: Iterable of batches as described above.
        device: Device the image and label tensors are moved to.
        epoch: Epoch number, used only in the progress-bar caption.

    Returns:
        Dict with keys ``'loss'`` (cross-entropy averaged over all
        samples), ``'accuracy'``, ``'precision'``, ``'recall'`` and
        ``'f1'``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    all_preds = []
    all_labels = []
    loss_sum = 0.0
    sample_count = 0

    with torch.no_grad():
        progress_bar = tqdm(loader, desc=f"Epoch {epoch} [VAL]")
        for batch in progress_bar:
            images = batch['image'].to(device)
            texts = batch['text']
            labels = batch['label'].to(device)

            logits = model(images, texts)
            loss = F.cross_entropy(logits, labels)

            batch_size = labels.size(0)
            # cross_entropy returns the batch mean; weight it by the batch
            # size so the reported loss is per-sample, like the other
            # metrics, even when the last batch is smaller.
            loss_sum += loss.item() * batch_size
            sample_count += batch_size

            preds = torch.argmax(logits, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())

    if sample_count == 0:
        raise ValueError("validate: loader yielded no samples")

    # NOTE(review): precision/recall/F1 use scikit-learn's default averaging,
    # which presumably expects binary labels - confirm against the model's
    # number of output classes.
    acc = accuracy_score(all_labels, all_preds)
    prec = precision_score(all_labels, all_preds, zero_division=0)
    rec = recall_score(all_labels, all_preds, zero_division=0)
    f1 = f1_score(all_labels, all_preds, zero_division=0)

    return {
        'loss': loss_sum / sample_count,
        'accuracy': acc,
        'precision': prec,
        'recall': rec,
        'f1': f1
    }
def main():
    """Train the auction authenticity model and save weights and history.

    Loads the dataset from ``../dataset``, splits it 80/20 into train and
    validation subsets, trains for a fixed number of epochs while printing
    per-epoch metrics, then writes the model ``state_dict`` and the metric
    history into ``../weights``. All paths are relative to the current
    working directory.
    """
    # Local import keeps this block self-contained.
    import os

    # Configuration
    BATCH_SIZE = 4
    EPOCHS = 5
    LEARNING_RATE = 2e-5
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    WEIGHTS_DIR = '../weights'

    # Create the output directory before training so a missing folder cannot
    # discard the trained model at save time.
    os.makedirs(WEIGHTS_DIR, exist_ok=True)

    print(f"🖥️ Device: {DEVICE}")
    print(f"📦 Batch size: {BATCH_SIZE}")
    print(f"📚 Epochs: {EPOCHS}")

    # Load the dataset
    print("\n📥 Ładowanie datasetu...")
    dataset = AuctionDatasetFromJSON(
        json_path='../dataset/dataset.json',
        root_dir='../dataset/raw_data',
        transform=get_transforms()
    )
    print(f"✓ {len(dataset)} aukcji załadowanych")

    # Split: 80% train, 20% val
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    print(f" - Train: {len(train_dataset)}")
    print(f" - Val: {len(val_dataset)}")

    # Model
    print("\n🧠 Inicjalizacja modelu...")
    model = AuctionAuthenticityModel(device=DEVICE).to(DEVICE)
    print(f"✓ Model gotowy ({model.count_parameters():,} parametrów)")

    # Optimizer
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

    # Training loop
    print("\n🚀 Rozpoczynam trening...\n")
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_accuracy': [],
        'val_f1': []
    }
    for epoch in range(EPOCHS):
        # Train
        train_loss = train_epoch(model, train_loader, optimizer, DEVICE, epoch+1)
        # Validate
        val_metrics = validate(model, val_loader, DEVICE, epoch+1)

        # Log
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_metrics['loss'])
        history['val_accuracy'].append(val_metrics['accuracy'])
        history['val_f1'].append(val_metrics['f1'])

        print(f"\n{'='*60}")
        print(f"Epoch {epoch+1}/{EPOCHS}")
        print(f" Train Loss: {train_loss:.4f}")
        print(f" Val Loss: {val_metrics['loss']:.4f}")
        print(f" Val Acc: {val_metrics['accuracy']:.4f}")
        print(f" Val Prec: {val_metrics['precision']:.4f}")
        print(f" Val Rec: {val_metrics['recall']:.4f}")
        print(f" Val F1: {val_metrics['f1']:.4f}")
        print(f"{'='*60}\n")

    # Save the model weights
    print("\n💾 Zapis modelu...")
    torch.save(model.state_dict(), os.path.join(WEIGHTS_DIR, 'auction_model.pt'))
    print("✓ Zapisano: weights/auction_model.pt")

    # Save the training history
    history_path = os.path.join(WEIGHTS_DIR, 'training_history.json')
    with open(history_path, 'w', encoding='utf-8') as f:
        json.dump(history, f, indent=2)
    print("✓ Zapisano: weights/training_history.json")
    print("\n✅ Trening ukończony!")
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()