# flowerfy/src/training/simple_train.py
# Toy example — commit b24c04f: "Apply code formatting and fix compatibility issues"
"""
Simple ConvNeXt training script without using the Transformers Trainer class.
Refactored version of the original simple_train.py
"""
import json
import os
import torch
from torch.utils.data import DataLoader
from transformers import ConvNextForImageClassification, ConvNextImageProcessor
from ..core.config import config
from ..core.constants import DEFAULT_CONVNEXT_MODEL, MODELS_DIR
from ..services.training.dataset import FlowerDataset
def simple_train():
    """Fine-tune a pretrained ConvNeXt classifier on local flower images.

    Reads images from ``training_data/images``, trains for 3 epochs with
    AdamW (lr=1e-5, batch size 4) on an 80% split, and saves the model,
    processor, and a JSON record of the hyperparameters to
    ``MODELS_DIR/simple_trained_convnext``.

    Returns:
        The output directory path on success, or ``None`` when the
        training data is missing or has fewer than 5 images.
    """
    print("🌸 Simple ConvNeXt Flower Model Training")
    print("=" * 40)

    # Bail out early when there is nothing to train on.
    images_dir = "training_data/images"
    if not os.path.exists(images_dir):
        print("❌ Training directory not found")
        return

    device = config.device
    print(f"Using device: {device}")

    # Load the pretrained model together with its matching image processor.
    model_name = DEFAULT_CONVNEXT_MODEL
    model = ConvNextForImageClassification.from_pretrained(model_name)
    processor = ConvNextImageProcessor.from_pretrained(model_name)
    model.to(device)

    dataset = FlowerDataset(images_dir, processor)
    if len(dataset) < 5:
        print("❌ Need at least 5 images for training")
        return

    # Swap the classification head when our label set differs from the
    # pretrained one.
    num_labels = len(dataset.flower_labels)
    if num_labels != model.config.num_labels:
        model.config.num_labels = num_labels
        # Keep the label maps in sync so the saved checkpoint can be
        # reloaded for inference with human-readable labels.
        model.config.id2label = dict(enumerate(dataset.flower_labels))
        model.config.label2id = {
            label: i for i, label in enumerate(dataset.flower_labels)
        }
        # ConvNeXt uses hidden_sizes[-1] as the final hidden dimension
        final_hidden_size = (
            model.config.hidden_sizes[-1]
            if hasattr(model.config, "hidden_sizes")
            else 768
        )
        # BUG FIX: model.to(device) already ran above, so the freshly
        # created head must be moved to the same device explicitly —
        # otherwise training on CUDA/MPS fails with a device mismatch.
        model.classifier = torch.nn.Linear(final_hidden_size, num_labels).to(device)

    # 80/20 split; only the training subset is used by this script.
    train_size = int(0.8 * len(dataset))
    train_dataset = torch.utils.data.Subset(dataset, range(train_size))

    def simple_collate_fn(batch):
        """Stack per-sample tensors into a single batched dict."""
        return {
            "pixel_values": torch.stack([item["pixel_values"] for item in batch]),
            "labels": torch.stack([item["labels"] for item in batch]),
        }

    train_loader = DataLoader(
        train_dataset, batch_size=4, shuffle=True, collate_fn=simple_collate_fn
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    # Training loop — 3 epochs, logging every other batch.
    model.train()
    print(f"Starting training on {len(train_dataset)} samples...")
    for epoch in range(3):
        total_loss = 0
        num_batches = 0
        for batch_idx, batch in enumerate(train_loader):
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()
            # The HF model computes the cross-entropy loss internally
            # when `labels` are supplied.
            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            if batch_idx % 2 == 0:
                print(
                    f"Epoch {epoch + 1}, Batch {batch_idx + 1}: Loss = {loss.item():.4f}"
                )
        avg_loss = total_loss / num_batches if num_batches > 0 else 0
        print(f"Epoch {epoch + 1} completed. Average loss: {avg_loss:.4f}")

    # Persist model + processor in the standard HF layout.
    output_dir = os.path.join(MODELS_DIR, "simple_trained_convnext")
    os.makedirs(output_dir, exist_ok=True)
    model.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)

    # Record the hyperparameters alongside the checkpoint.
    config_data = {
        "model_name": model_name,
        "flower_labels": dataset.flower_labels,
        "num_epochs": 3,
        "batch_size": 4,
        "learning_rate": 1e-5,
        "train_samples": len(train_dataset),
        "num_labels": num_labels,
    }
    with open(os.path.join(output_dir, "training_config.json"), "w") as f:
        json.dump(config_data, f, indent=2)

    print(f"✅ ConvNeXt training completed! Model saved to {output_dir}")
    return output_dir
if __name__ == "__main__":
    # Script entry point: run training and surface any failure instead of
    # dying with a bare traceback.
    try:
        simple_train()
    except KeyboardInterrupt:
        print("\n⚠️ Training interrupted by user.")
    except Exception as e:
        import traceback

        print(f"❌ Training failed: {e}")
        traceback.print_exc()