"""Gradio app that trains a LoRA adapter on the head of a ResNet-18.

The script reads an ``{image filename: description}`` mapping from a JSON
file, wraps the final fully connected layer of a pretrained ResNet-18 in a
low-rank adapter, runs a short training loop and saves the resulting
``state_dict``.

NOTE(review): the training loop still uses random placeholder labels, so the
loss it reports is not meaningful until descriptions are mapped to real
targets.
"""

import json
import os
import shutil

import gradio as gr
import torch
from PIL import Image
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, models, transforms

# Paths
image_folder = "Images/"
metadata_file = "descriptions.json"


def load_metadata(metadata_file):
    """Load the image metadata from a JSON file.

    Args:
        metadata_file: Path of the JSON file to read.

    Returns:
        The decoded JSON document. The rest of this script expects a dict
        that maps image filenames to their descriptions.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(metadata_file, "r", encoding="utf-8") as f:
        return json.load(f)


class ImageDescriptionDataset(Dataset):
    """Dataset yielding ``(image_tensor, description)`` pairs.

    Args:
        image_folder: Directory that contains the image files.
        metadata: Mapping of image filename to description.
    """

    def __init__(self, image_folder, metadata):
        self.image_folder = image_folder
        self.metadata = metadata
        self.image_names = list(metadata)  # List of image filenames
        # Resize to 512x512, convert to a tensor, then normalise every
        # channel with mean 0.5 / std 0.5.
        self.transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def __len__(self):
        """Return the number of images listed in the metadata."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Return the transformed image and its description for ``idx``."""
        image_name = self.image_names[idx]
        image_path = os.path.join(self.image_folder, image_name)
        # ``convert`` returns an independent image, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        description = self.metadata[image_name]
        return self.transform(image), description


class LoRALayer(nn.Module):
    """Low-rank adapter around a frozen linear layer.

    The output is ``original_layer(x) + lora_down(lora_up(x))``.

    NOTE: the attribute names are kept for state-dict compatibility even
    though they are the reverse of the usual LoRA naming: ``lora_up`` maps
    the input to the rank-sized bottleneck and ``lora_down`` maps the
    bottleneck to the output.

    Args:
        original_layer: Layer to adapt; must expose ``in_features`` and
            ``out_features``. Its parameters are frozen by this wrapper.
        rank: Size of the low-rank bottleneck; must be at least 1.

    Raises:
        ValueError: If ``rank`` is smaller than 1.
    """

    def __init__(self, original_layer, rank=4):
        super().__init__()
        if rank < 1:
            raise ValueError(f"rank must be a positive integer, got {rank}")
        self.original_layer = original_layer
        self.rank = rank
        # LoRA trains only the low-rank update; the wrapped weights stay
        # fixed.
        for param in self.original_layer.parameters():
            param.requires_grad_(False)
        self.lora_up = nn.Linear(original_layer.in_features, rank, bias=False)
        self.lora_down = nn.Linear(rank, original_layer.out_features, bias=False)
        # A zero-initialised second projection makes the adapter a no-op at
        # the start, so the wrapped layer initially behaves exactly like the
        # original one.
        nn.init.zeros_(self.lora_down.weight)

    def forward(self, x):
        """Return the frozen layer's output plus the low-rank update."""
        return self.original_layer(x) + self.lora_down(self.lora_up(x))


def _load_pretrained_resnet18():
    """Build a pretrained ResNet-18 on both new and old torchvision APIs."""
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is not None:
        # ``pretrained=True`` is deprecated in newer torchvision releases;
        # IMAGENET1K_V1 is the weight set that keyword selected.
        return models.resnet18(weights=weights_enum.IMAGENET1K_V1)
    return models.resnet18(pretrained=True)


class LoRAModel(nn.Module):
    """Pretrained ResNet-18 whose final ``fc`` layer is wrapped in LoRA.

    Only the wrapped ``fc`` weights are frozen; the remaining backbone
    layers stay trainable, as in the original script.

    Args:
        rank: Rank of the LoRA adapter placed on the final layer.
    """

    def __init__(self, rank=4):
        super().__init__()
        self.backbone = _load_pretrained_resnet18()  # Base model
        # Replace the final layer with its LoRA-wrapped version.
        self.backbone.fc = LoRALayer(self.backbone.fc, rank=rank)

    def forward(self, x):
        """Run ``x`` through the adapted backbone."""
        return self.backbone(x)


def train_lora(image_folder, metadata, num_epochs=5, batch_size=8,
               learning_rate=0.001, save_path="/mnt/data/lora_model.pth"):
    """Train the LoRA model and save its ``state_dict``.

    Args:
        image_folder: Directory that contains the training images.
        metadata: Mapping of image filename to description.
        num_epochs: Number of passes over the dataset.
        batch_size: Number of images per batch.
        learning_rate: Learning rate for the Adam optimizer.
        save_path: File the trained ``state_dict`` is written to. The parent
            directory is created when it does not exist.

    Returns:
        A status message that contains a download link for the saved model.
    """
    print("Starting LoRA training process...")

    # Create dataset and dataloader
    dataset = ImageDescriptionDataset(image_folder, metadata)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Initialize model, loss function, and optimizer
    model = LoRAModel()
    model.train()
    criterion = nn.CrossEntropyLoss()  # Update this if your task changes
    # Frozen parameters never receive gradients, so keep them out of the
    # optimizer altogether.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=learning_rate)

    # Training loop
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        for batch_idx, (images, _descriptions) in enumerate(dataloader):
            # TODO: derive real labels from the descriptions. These random
            # placeholders only exercise the training loop.
            labels = torch.randint(0, 100, (images.size(0),))

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0:  # Log every 10 batches
                print(f"Batch {batch_idx}, Loss: {loss.item()}")

    # Make sure the target directory exists so a full training run is not
    # lost to a missing folder at save time.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), save_path)
    print(f"Model saved at {save_path}")

    print(f"Training completed. The model is saved and ready for download at {save_path}.")
    return f"Training completed. Download the model from: [Download Model](sandbox:{save_path})"


# Gradio App
def start_training_gradio():
    """Load the metadata file and run training; used as the Gradio callback.

    Returns:
        The status message produced by ``train_lora``.
    """
    print("Loading metadata and preparing dataset...")
    metadata = load_metadata(metadata_file)
    return train_lora(image_folder, metadata)


demo = gr.Interface(
    fn=start_training_gradio,
    inputs=None,
    outputs="text",
    title="Train LoRA Model",
    description="Fine-tune a model using LoRA for consistent image generation.",
)

# Kept at module level so that running or importing this file starts the app
# exactly as it did before.
demo.launch()