Spaces:
Sleeping
Sleeping
| import logging | |
| import torch.nn as nn | |
| from torch.optim import AdamW | |
| from src.config import DEVICE, EPOCHS, NUM_CLASSES | |
| from src.models.fusion_model import FusionClassifier | |
| from src.data.dataset import create_fusion_dataloaders | |
| from src.training.trainer import train_dual_input_model | |
logger = logging.getLogger(__name__)


def _build_optimizer(model, weight_decay=1e-4):
    """Build an AdamW optimizer with per-block learning rates for ``model``.

    Only the sub-modules listed below are handed to the optimizer; any other
    parameter of ``model`` is never stepped by it. A warning is logged if
    some parameter with ``requires_grad=True`` falls outside every group, since
    that usually means the group list and the model's freezing are out of sync.

    Args:
        model: The fusion model; must expose ``eff_features``,
            ``cnx_backbone.encoder.stages``, ``cnx_backbone.layernorm`` and
            ``fusion_head`` as used below.
        weight_decay: Decoupled weight decay applied to every parameter group.

    Returns:
        The configured ``torch.optim.AdamW`` instance.
    """
    param_groups = [
        # EfficientNet unfrozen blocks (block 5 deliberately gets a lower LR).
        {"params": model.eff_features[5].parameters(), "lr": 1e-5},
        {"params": model.eff_features[6].parameters(), "lr": 3e-5},
        {"params": model.eff_features[7].parameters(), "lr": 3e-5},
        # ConvNeXt unfrozen blocks
        {"params": model.cnx_backbone.encoder.stages[2].parameters(), "lr": 3e-5},
        {"params": model.cnx_backbone.encoder.stages[3].parameters(), "lr": 3e-5},
        {"params": model.cnx_backbone.layernorm.parameters(), "lr": 3e-5},
        # Fusion head gets the highest LR of all groups.
        {"params": model.fusion_head.parameters(), "lr": 1e-4},
    ]
    optimizer = AdamW(param_groups, weight_decay=weight_decay)

    # AdamW materialises each group's "params" into a list, so we can inspect
    # which tensors it actually tracks and flag trainable ones it does not.
    tracked = {id(p) for group in optimizer.param_groups for p in group["params"]}
    untracked = [
        name
        for name, param in model.named_parameters()
        if param.requires_grad and id(param) not in tracked
    ]
    if untracked:
        logger.warning(
            "%d trainable parameter tensor(s) are not in any optimizer group "
            "and will not be updated (first: %s)",
            len(untracked),
            untracked[0],
        )
    return optimizer


def run_fusion_training(
    *,
    epochs=None,
    patience=7,
    label_smoothing=0.1,
    weight_decay=1e-4,
):
    """Build and train the fusion classifier, returning the trainer's outputs.

    The pipeline runs in order:

    1. Create the train/eval loaders via ``create_fusion_dataloaders``.
    2. Build ``FusionClassifier(num_classes=NUM_CLASSES)`` on ``DEVICE``.
    3. Set up a label-smoothed cross-entropy loss and an AdamW optimizer with
       per-block learning rates.
    4. Delegate the loop to ``train_dual_input_model``, with checkpoints named
       ``"best_fusion_model"``.

    All arguments are keyword-only and default to the original hard-coded
    values, so a bare ``run_fusion_training()`` behaves as before.

    Args:
        epochs: Number of training epochs; ``None`` means use ``EPOCHS``
            from ``src.config`` (read at call time).
        patience: Forwarded to the trainer as ``patience`` (presumably an
            early-stopping patience in epochs — confirm in the trainer).
        label_smoothing: Label-smoothing factor for ``nn.CrossEntropyLoss``.
        weight_decay: AdamW weight decay applied to every parameter group.

    Returns:
        ``(all_preds, all_labels)`` exactly as returned by
        ``train_dual_input_model`` (presumably eval-set predictions and
        targets — confirm against the trainer).
    """
    if epochs is None:
        epochs = EPOCHS

    logger.info("Initializing Fusion training pipeline...")
    train_loader, eval_loader = create_fusion_dataloaders()

    model = FusionClassifier(num_classes=NUM_CLASSES).to(DEVICE)
    criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
    optimizer = _build_optimizer(model, weight_decay=weight_decay)

    logger.info("Starting Fusion training...")
    all_preds, all_labels = train_dual_input_model(
        model=model,
        train_loader=train_loader,
        eval_loader=eval_loader,
        optimizer=optimizer,
        criterion=criterion,
        device=DEVICE,
        epochs=epochs,
        checkpoint_model_name="best_fusion_model",
        patience=patience,
    )
    logger.info("Fusion training completed.")
    return all_preds, all_labels
if __name__ == "__main__":
    # Script entry point: configure root logging, run training, then echo a
    # small prefix of the returned predictions and labels to stdout.
    log_format = "%(asctime)s - %(levelname)s - %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)

    predictions, targets = run_fusion_training()

    sample_size = 10
    print("\nFusion training completed successfully.")
    print("Prediction samples:", predictions[:sample_size])
    print("Label samples:", targets[:sample_size])