Spaces:
Build error
Build error
kapil
feat: implement core RustAutoScoreEngine framework including data loading, model architecture, training loop, inference, and GUI server
8e03aff | use crate::data::{DartBatcher, DartDataset}; | |
| use crate::loss::diou_loss; | |
| use crate::model::DartVisionModel; | |
| use burn::data::dataset::Dataset; // Add this trait to scope | |
| use burn::module::Module; | |
| use burn::optim::{AdamConfig, GradientsParams, Optimizer}; | |
| use burn::prelude::*; | |
| use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder}; | |
| use burn::tensor::backend::AutodiffBackend; | |
/// Hyperparameters controlling a single training run.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Number of full passes over the dataset.
    pub num_epochs: usize,
    /// Number of examples per mini-batch handed to the dataloader.
    pub batch_size: usize,
    /// Learning rate supplied to the Adam optimizer at every step.
    pub lr: f64,
}
| pub fn train<B: AutodiffBackend>(device: Device<B>, dataset_path: &str, config: TrainingConfig) { | |
| // 1. Create Model | |
| let mut model: DartVisionModel<B> = DartVisionModel::new(&device); | |
| // 1.5 Load existing weights if they exist (RESUME) | |
| let recorder = BinFileRecorder::<FullPrecisionSettings>::default(); | |
| let weights_path = "model_weights.bin"; | |
| if std::path::Path::new(weights_path).exists() { | |
| println!("π Loading existing weights from {}...", weights_path); | |
| let record = Recorder::load(&recorder, "model_weights".into(), &device) | |
| .expect("Failed to load weights"); | |
| model = model.load_record(record); | |
| } | |
| // 2. Setup Optimizer | |
| let mut optim = AdamConfig::new().init(); | |
| // 3. Create Dataset | |
| println!("π Mapping annotations from {}...", dataset_path); | |
| let dataset = DartDataset::load(dataset_path, "dataset/800"); | |
| println!("π Dataset loaded with {} examples.", dataset.len()); | |
| let batcher = DartBatcher::new(device.clone()); | |
| // 4. Create DataLoader | |
| println!("π¦ Initializing DataLoader (Workers: 4)..."); | |
| let dataloader = burn::data::dataloader::DataLoaderBuilder::new(batcher) | |
| .batch_size(config.batch_size) | |
| .shuffle(42) | |
| .num_workers(4) | |
| .build(dataset); | |
| // 5. Training Loop | |
| println!( | |
| "π Running FULL Training Loop (Epochs: {})...", | |
| config.num_epochs | |
| ); | |
| // Using a simple loop state for ownership safety | |
| let mut current_model = model; | |
| for epoch in 1..=config.num_epochs { | |
| let mut model_inner = current_model; // Move into epoch | |
| let mut batch_count = 0; | |
| for batch in dataloader.iter() { | |
| // Forward Pass | |
| let (out16, _) = model_inner.forward(batch.images); | |
| // Calculate Loss | |
| let loss = diou_loss(out16, batch.targets); | |
| batch_count += 1; | |
| // Print every 20 batches β use detach() to avoid cloning the full autodiff graph | |
| if batch_count % 20 == 0 || batch_count == 1 { | |
| let loss_val = loss.clone().detach().into_scalar(); | |
| println!( | |
| " [Epoch {}] Batch {: >3} | Loss: {:.6}", | |
| epoch, | |
| batch_count, | |
| loss_val | |
| ); | |
| } | |
| // Backward & Optimization step | |
| let grads = loss.backward(); | |
| let grads_params = GradientsParams::from_grads(grads, &model_inner); | |
| model_inner = optim.step(config.lr, model_inner, grads_params); | |
| // 5.5 Periodic Save (every 100 batches and Batch 1) | |
| if batch_count % 100 == 0 || batch_count == 1 { | |
| model_inner.clone() | |
| .save_file("model_weights", &recorder) | |
| .ok(); | |
| if batch_count == 1 { | |
| println!("π [Checkpoint] Initial weights saved at Batch 1."); | |
| } else { | |
| println!("π [Checkpoint] Saved at Batch {}.", batch_count); | |
| } | |
| } | |
| } | |
| // 6. SAVE after EACH Epoch | |
| model_inner | |
| .clone() | |
| .save_file("model_weights", &recorder) | |
| .expect("Failed to save weights"); | |
| println!("β Checkpoint saved: Epoch {} complete.", epoch); | |
| current_model = model_inner; // Move back out for next epoch | |
| } | |
| } | |