| """ |
| Simple ConvNeXt training script without using the Transformers Trainer class. |
| Refactored version of the original simple_train.py |
| """ |
|
|
| import json |
| import os |
|
|
| import torch |
| from torch.utils.data import DataLoader |
| from transformers import ConvNextForImageClassification, ConvNextImageProcessor |
|
|
| from ..core.config import config |
| from ..core.constants import DEFAULT_CONVNEXT_MODEL, MODELS_DIR |
| from ..services.training.dataset import FlowerDataset |
|
|
|
|
def simple_train(*, num_epochs=3, batch_size=4, learning_rate=1e-5):
    """Fine-tune a pretrained ConvNeXt classifier on the local flower dataset.

    Loads images from ``training_data/images``, replaces the classification
    head when the dataset's label count differs from the checkpoint, trains
    with AdamW, and saves the model, processor, and a JSON training summary
    under ``MODELS_DIR/simple_trained_convnext``.

    Args:
        num_epochs: Number of passes over the training split.
        batch_size: DataLoader batch size.
        learning_rate: AdamW learning rate.

    Returns:
        The output directory path on success, or ``None`` when the training
        data is missing or has fewer than 5 images.
    """
    print("🌸 Simple ConvNeXt Flower Model Training")
    print("=" * 40)

    images_dir = "training_data/images"
    if not os.path.exists(images_dir):
        print("❌ Training directory not found")
        return None

    device = config.device
    print(f"Using device: {device}")

    model_name = DEFAULT_CONVNEXT_MODEL
    model = ConvNextForImageClassification.from_pretrained(model_name)
    processor = ConvNextImageProcessor.from_pretrained(model_name)
    model.to(device)

    dataset = FlowerDataset(images_dir, processor)

    if len(dataset) < 5:
        print("❌ Need at least 5 images for training")
        return None

    num_labels = len(dataset.flower_labels)
    if num_labels != model.config.num_labels:
        model.config.num_labels = num_labels
        # Keep the label maps in sync so the saved model reports our flower
        # names at inference time (bug fix: these were previously left stale,
        # still holding the checkpoint's original class names).
        model.config.id2label = dict(enumerate(dataset.flower_labels))
        model.config.label2id = {
            label: idx for idx, label in enumerate(dataset.flower_labels)
        }

        final_hidden_size = (
            model.config.hidden_sizes[-1]
            if hasattr(model.config, "hidden_sizes")
            else 768
        )
        # Bug fix: the replacement head must be moved to `device` explicitly —
        # it is created after model.to(device), so it would otherwise stay on
        # CPU and crash the forward pass on CUDA/MPS.
        model.classifier = torch.nn.Linear(final_hidden_size, num_labels).to(device)

    # Deterministic split: first 80% of samples train; the remaining 20% is
    # currently unused (no evaluation loop in this simple script).
    train_size = int(0.8 * len(dataset))
    train_dataset = torch.utils.data.Subset(dataset, range(train_size))

    def simple_collate_fn(batch):
        """Stack per-sample tensors into batched pixel_values/labels tensors."""
        return {
            "pixel_values": torch.stack([item["pixel_values"] for item in batch]),
            "labels": torch.stack([item["labels"] for item in batch]),
        }

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=simple_collate_fn,
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    model.train()
    print(f"Starting training on {len(train_dataset)} samples...")

    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()

            # ConvNextForImageClassification computes cross-entropy internally
            # when `labels` is passed.
            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            if batch_idx % 2 == 0:
                print(
                    f"Epoch {epoch + 1}, Batch {batch_idx + 1}: Loss = {loss.item():.4f}"
                )

        avg_loss = total_loss / num_batches if num_batches > 0 else 0
        print(f"Epoch {epoch + 1} completed. Average loss: {avg_loss:.4f}")

    output_dir = os.path.join(MODELS_DIR, "simple_trained_convnext")
    os.makedirs(output_dir, exist_ok=True)

    model.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)

    # Record the hyperparameters actually used (previously hard-coded here as
    # literals, which could silently desync from the training loop above).
    config_data = {
        "model_name": model_name,
        "flower_labels": dataset.flower_labels,
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "learning_rate": learning_rate,
        "train_samples": len(train_dataset),
        "num_labels": num_labels,
    }

    with open(
        os.path.join(output_dir, "training_config.json"), "w", encoding="utf-8"
    ) as f:
        json.dump(config_data, f, indent=2)

    print(f"✅ ConvNeXt training completed! Model saved to {output_dir}")
    return output_dir
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run training directly, exiting quietly on Ctrl-C
    # and printing a full traceback for any unexpected failure.
    try:
        simple_train()
    except KeyboardInterrupt:
        print("\n⚠️ Training interrupted by user.")
    except Exception as e:
        print(f"❌ Training failed: {e}")
        # Imported lazily: only needed on the failure path.
        import traceback


        traceback.print_exc()
|
|