// kapil
// feat: implement core RustAutoScoreEngine framework including data loading, model architecture, training loop, inference, and GUI server
// 8e03aff
use crate::data::{DartBatcher, DartDataset};
use crate::loss::diou_loss;
use crate::model::DartVisionModel;
use burn::data::dataset::Dataset; // Add this trait to scope
use burn::module::Module;
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
use burn::prelude::*;
use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
use burn::tensor::backend::AutodiffBackend;
/// Hyperparameters controlling a single training run.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Number of full passes over the dataset.
    pub num_epochs: usize,
    /// Number of examples per mini-batch.
    pub batch_size: usize,
    /// Learning rate passed to the Adam optimizer on every step.
    pub lr: f64,
}
/// Runs the full training loop for the dart-vision model.
///
/// Resumes from `model_weights.bin` in the working directory when present,
/// then trains for `config.num_epochs` epochs over the annotations at
/// `dataset_path` (images under `dataset/800`), checkpointing the weights
/// every 100 batches, at batch 1, and at the end of every epoch.
///
/// # Panics
/// Panics if an existing checkpoint fails to load, or if the end-of-epoch
/// save fails. Periodic mid-epoch saves are best-effort and only warn.
pub fn train<B: AutodiffBackend>(device: Device<B>, dataset_path: &str, config: TrainingConfig) {
    // Single source of truth for the checkpoint file stem; BinFileRecorder
    // appends the ".bin" extension on both save and load.
    const WEIGHTS_STEM: &str = "model_weights";
    let weights_path = format!("{WEIGHTS_STEM}.bin");

    // 1. Create the model, resuming from an existing checkpoint when available.
    let mut model: DartVisionModel<B> = DartVisionModel::new(&device);
    let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
    if std::path::Path::new(&weights_path).exists() {
        println!("🚀 Loading existing weights from {}...", weights_path);
        let record = Recorder::load(&recorder, WEIGHTS_STEM.into(), &device)
            .expect("Failed to load weights");
        model = model.load_record(record);
    }

    // 2. Setup optimizer.
    let mut optim = AdamConfig::new().init();

    // 3. Create dataset.
    // NOTE(review): the image root "dataset/800" is hard-coded — confirm it
    // matches the layout produced by the annotation pipeline.
    println!("🔍 Mapping annotations from {}...", dataset_path);
    let dataset = DartDataset::load(dataset_path, "dataset/800");
    println!("📊 Dataset loaded with {} examples.", dataset.len());
    let batcher = DartBatcher::new(device.clone());

    // 4. Create dataloader (fixed shuffle seed 42 for reproducible epochs).
    println!("📦 Initializing DataLoader (Workers: 4)...");
    let dataloader = burn::data::dataloader::DataLoaderBuilder::new(batcher)
        .batch_size(config.batch_size)
        .shuffle(42)
        .num_workers(4)
        .build(dataset);

    // 5. Training loop. The model is re-assigned in place on each optimizer
    // step, so no per-epoch ownership shuffling is needed.
    println!(
        "📈 Running FULL Training Loop (Epochs: {})...",
        config.num_epochs
    );
    for epoch in 1..=config.num_epochs {
        let mut batch_count = 0usize;
        for batch in dataloader.iter() {
            // Forward pass (only the 16-stride head feeds the loss here).
            let (out16, _) = model.forward(batch.images);
            let loss = diou_loss(out16, batch.targets);
            batch_count += 1;

            // Print every 20 batches — detach() avoids cloning the full
            // autodiff graph just to read the scalar.
            if batch_count % 20 == 0 || batch_count == 1 {
                let loss_val = loss.clone().detach().into_scalar();
                println!(
                    " [Epoch {}] Batch {: >3} | Loss: {:.6}",
                    epoch, batch_count, loss_val
                );
            }

            // Backward & optimization step.
            let grads = loss.backward();
            let grads_params = GradientsParams::from_grads(grads, &model);
            model = optim.step(config.lr, model, grads_params);

            // 5.5 Periodic best-effort save (every 100 batches and batch 1).
            // save_file consumes the module, hence the clone.
            if batch_count % 100 == 0 || batch_count == 1 {
                match model.clone().save_file(WEIGHTS_STEM, &recorder) {
                    Err(err) => {
                        eprintln!("⚠️ [Checkpoint] Save failed at Batch {}: {:?}", batch_count, err)
                    }
                    Ok(()) if batch_count == 1 => {
                        println!("🚀 [Checkpoint] Initial weights saved at Batch 1.")
                    }
                    Ok(()) => println!("🚀 [Checkpoint] Saved at Batch {}.", batch_count),
                }
            }
        }

        // 6. Save after each epoch — this one is mandatory.
        model
            .clone()
            .save_file(WEIGHTS_STEM, &recorder)
            .expect("Failed to save weights");
        println!("✅ Checkpoint saved: Epoch {} complete.", epoch);
    }
}