"""Training script for the auction authenticity classifier.

Loads the auction dataset described by a JSON index, splits it 80/20 into
train and validation subsets, fine-tunes ``AuctionAuthenticityModel`` with
AdamW, and writes the final weights plus a per-epoch metric history to the
``../weights`` directory.
"""
import json
import os

import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from torch.optim import AdamW
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from dataset_loader import AuctionDatasetFromJSON, get_transforms
from model import AuctionAuthenticityModel


def train_epoch(model, loader, optimizer, device, epoch):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Module called as ``model(images, texts)`` that returns logits.
        loader: Iterable of batches; each batch is a mapping with the keys
            ``'image'``, ``'text'`` and ``'label'``.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the image and label tensors are moved to.
        epoch: Epoch number, used only for the progress-bar caption.

    Returns:
        Sum of the per-batch cross-entropy losses divided by ``len(loader)``.
    """
    model.train()
    total_loss = 0

    progress_bar = tqdm(loader, desc=f"Epoch {epoch} [TRAIN]")
    for batch in progress_bar:
        images = batch['image'].to(device)
        # Texts are passed to the model as-is; only tensors are moved here.
        texts = batch['text']
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        logits = model(images, texts)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        # Cap the global gradient norm at 1.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        progress_bar.set_postfix(loss=f'{loss.item():.4f}')

    return total_loss / len(loader)


def validate(model, loader, device, epoch):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Module called as ``model(images, texts)`` that returns logits.
        loader: Iterable of batches; each batch is a mapping with the keys
            ``'image'``, ``'text'`` and ``'label'``.
        device: Device the image and label tensors are moved to.
        epoch: Epoch number, used only for the progress-bar caption.

    Returns:
        Dict with the keys ``'loss'`` (mean batch cross-entropy),
        ``'accuracy'``, ``'precision'``, ``'recall'`` and ``'f1'``.
    """
    model.eval()
    all_preds = []
    all_labels = []
    total_loss = 0

    with torch.no_grad():
        progress_bar = tqdm(loader, desc=f"Epoch {epoch} [VAL]")
        for batch in progress_bar:
            images = batch['image'].to(device)
            texts = batch['text']
            labels = batch['label'].to(device)

            logits = model(images, texts)
            loss = F.cross_entropy(logits, labels)
            total_loss += loss.item()

            # Predicted class = index of the largest logit in each row.
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())

    # NOTE(review): precision/recall/F1 use scikit-learn's default averaging,
    # which presumably means the labels are binary - confirm against the dataset.
    acc = accuracy_score(all_labels, all_preds)
    prec = precision_score(all_labels, all_preds, zero_division=0)
    rec = recall_score(all_labels, all_preds, zero_division=0)
    f1 = f1_score(all_labels, all_preds, zero_division=0)

    return {
        'loss': total_loss / len(loader),
        'accuracy': acc,
        'precision': prec,
        'recall': rec,
        'f1': f1
    }


def main():
    """Train the model end to end and persist the weights and metric history."""
    # Configuration
    BATCH_SIZE = 4
    EPOCHS = 5
    LEARNING_RATE = 2e-5
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"🖥️ Device: {DEVICE}")
    print(f"📦 Batch size: {BATCH_SIZE}")
    print(f"📚 Epochs: {EPOCHS}")

    # Create the output directory up front: a missing or unwritable
    # '../weights' should fail here, not after every epoch has already run
    # and the trained weights exist only in memory.
    os.makedirs('../weights', exist_ok=True)

    # Load the dataset
    print("\n📥 Ładowanie datasetu...")
    dataset = AuctionDatasetFromJSON(
        json_path='../dataset/dataset.json',
        root_dir='../dataset/raw_data',
        transform=get_transforms()
    )
    print(f"✓ {len(dataset)} aukcji załadowanych")

    # Split: 80% train, 20% val
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    print(f" - Train: {len(train_dataset)}")
    print(f" - Val: {len(val_dataset)}")

    # Model
    print("\n🧠 Inicjalizacja modelu...")
    model = AuctionAuthenticityModel(device=DEVICE).to(DEVICE)
    print(f"✓ Model gotowy ({model.count_parameters():,} parametrów)")

    # Optimizer
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

    # Training loop
    print("\n🚀 Rozpoczynam trening...\n")
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_accuracy': [],
        'val_f1': []
    }

    for epoch in range(EPOCHS):
        # Train
        train_loss = train_epoch(model, train_loader, optimizer, DEVICE, epoch+1)

        # Validate
        val_metrics = validate(model, val_loader, DEVICE, epoch+1)

        # Log. Values are coerced to built-in floats so the JSON dump below
        # does not depend on which scalar type the metric helpers return.
        history['train_loss'].append(float(train_loss))
        history['val_loss'].append(float(val_metrics['loss']))
        history['val_accuracy'].append(float(val_metrics['accuracy']))
        history['val_f1'].append(float(val_metrics['f1']))

        print(f"\n{'='*60}")
        print(f"Epoch {epoch+1}/{EPOCHS}")
        print(f" Train Loss: {train_loss:.4f}")
        print(f" Val Loss: {val_metrics['loss']:.4f}")
        print(f" Val Acc: {val_metrics['accuracy']:.4f}")
        print(f" Val Prec: {val_metrics['precision']:.4f}")
        print(f" Val Rec: {val_metrics['recall']:.4f}")
        print(f" Val F1: {val_metrics['f1']:.4f}")
        print(f"{'='*60}\n")

    # Save the model (final-epoch weights only)
    print("\n💾 Zapis modelu...")
    torch.save(model.state_dict(), '../weights/auction_model.pt')
    print("✓ Zapisano: weights/auction_model.pt")

    # Save the training history
    with open('../weights/training_history.json', 'w', encoding='utf-8') as f:
        json.dump(history, f, indent=2)
    print("✓ Zapisano: weights/training_history.json")

    print("\n✅ Trening ukończony!")


if __name__ == '__main__':
    main()