File size: 3,928 Bytes
9874885
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e03aff
9874885
8e03aff
9874885
 
 
 
8e03aff
9874885
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
use crate::data::{DartBatcher, DartDataset};
use crate::loss::diou_loss;
use crate::model::DartVisionModel;
use burn::data::dataset::Dataset; // Add this trait to scope
use burn::module::Module;
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
use burn::prelude::*;
use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
use burn::tensor::backend::AutodiffBackend;

/// Hyper-parameters consumed by `train`.
///
/// All fields are plain `Copy` values, so the config is cheap to copy and can be
/// printed with `{:?}` for run logs.
#[derive(Debug, Clone, Copy)]
pub struct TrainingConfig {
    /// Number of full passes over the dataset.
    pub num_epochs: usize,
    /// Number of examples per batch handed to the data loader.
    pub batch_size: usize,
    /// Learning rate passed to every optimizer step.
    pub lr: f64,
}

/// Trains a `DartVisionModel` on the annotations found at `dataset_path`.
///
/// If `model_weights.bin` exists in the working directory, its weights are loaded first,
/// so training resumes from the last checkpoint.
///
/// Each batch goes through these steps:
/// - a forward pass;
/// - `diou_loss` on the first element of the model's output tuple;
/// - back-propagation;
/// - one Adam step with learning rate `config.lr`.
///
/// Weights are written back to `model_weights.bin` at these points:
/// - after the first batch and every 100th batch of each epoch (best-effort: a failed
///   save is reported on stderr and training continues);
/// - at the end of every epoch.
///
/// NOTE(review): `"dataset/800"` is passed as the second argument of `DartDataset::load`
/// and is hard-coded here; confirm its meaning against `DartDataset`.
///
/// # Panics
///
/// Panics if an existing weights file cannot be loaded, or if the end-of-epoch
/// checkpoint cannot be written.
pub fn train<B: AutodiffBackend>(device: Device<B>, dataset_path: &str, config: TrainingConfig) {
    // Single source of truth for the checkpoint location. The recorder appends the
    // `.bin` extension itself, so load/save take the stem while the existence check
    // needs the full file name.
    const WEIGHTS_STEM: &str = "model_weights";
    const NUM_WORKERS: usize = 4;
    const SHUFFLE_SEED: u64 = 42;
    const LOG_EVERY: usize = 20;
    const CHECKPOINT_EVERY: usize = 100;

    // 1. Create the model, resuming from a previous checkpoint when one exists.
    let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
    let weights_file = std::path::Path::new(WEIGHTS_STEM).with_extension("bin");
    let mut model: DartVisionModel<B> = DartVisionModel::new(&device);
    if weights_file.exists() {
        println!("🚀 Loading existing weights from {}...", weights_file.display());
        let record = Recorder::load(&recorder, WEIGHTS_STEM.into(), &device).unwrap_or_else(
            |err| {
                panic!(
                    "Failed to load weights from {}: {:?}",
                    weights_file.display(),
                    err
                )
            },
        );
        model = model.load_record(record);
    }

    // 2. Set up the optimizer.
    let mut optim = AdamConfig::new().init();

    // 3. Create the dataset.
    println!("🔍 Mapping annotations from {}...", dataset_path);
    let dataset = DartDataset::load(dataset_path, "dataset/800");
    println!("📊 Dataset loaded with {} examples.", dataset.len());
    let batcher = DartBatcher::new(device.clone());

    // 4. Create the data loader.
    println!("📦 Initializing DataLoader (Workers: {})...", NUM_WORKERS);
    let dataloader = burn::data::dataloader::DataLoaderBuilder::new(batcher)
        .batch_size(config.batch_size)
        .shuffle(SHUFFLE_SEED)
        .num_workers(NUM_WORKERS)
        .build(dataset);

    // 5. Training loop.
    println!(
        "📈 Running FULL Training Loop (Epochs: {})...",
        config.num_epochs
    );

    for epoch in 1..=config.num_epochs {
        for (index, batch) in dataloader.iter().enumerate() {
            let batch_num = index + 1;

            // Forward pass and loss.
            let (out16, _) = model.forward(batch.images);
            let loss = diou_loss(out16, batch.targets);

            if batch_num == 1 || batch_num % LOG_EVERY == 0 {
                // `into_scalar` consumes its tensor, so read the value from a detached
                // clone and keep `loss` itself for `backward()` below.
                let loss_val = loss.clone().detach().into_scalar();
                println!(
                    "   [Epoch {}] Batch {: >3} | Loss: {:.6}",
                    epoch, batch_num, loss_val
                );
            }

            // Backward pass and optimizer step. `step` consumes the model and returns
            // the updated one, so the binding is simply rebound.
            let grads = loss.backward();
            let grads_params = GradientsParams::from_grads(grads, &model);
            model = optim.step(config.lr, model, grads_params);

            // Best-effort mid-epoch checkpoint. A failed save must not abort a long
            // run, but it must not be reported as a success either.
            if batch_num == 1 || batch_num % CHECKPOINT_EVERY == 0 {
                match model.clone().save_file(WEIGHTS_STEM, &recorder) {
                    Ok(_) => println!(
                        "🚀 [Checkpoint] Saved at Epoch {}, Batch {}.",
                        epoch, batch_num
                    ),
                    Err(err) => eprintln!(
                        "⚠️ [Checkpoint] Failed to save at Epoch {}, Batch {}: {:?}",
                        epoch, batch_num, err
                    ),
                }
            }
        }

        // 6. End-of-epoch checkpoint. Unlike the mid-epoch saves, a failure here is fatal.
        model
            .clone()
            .save_file(WEIGHTS_STEM, &recorder)
            .expect("Failed to save weights");
        println!("✅ Checkpoint saved: Epoch {} complete.", epoch);
    }
}