| import torch |
| from torch import nn, optim |
| from torchvision import transforms, datasets, models |
| from torch.utils.data import DataLoader, Dataset |
| from PIL import Image |
| import json |
| import os |
| import gradio as gr |
|
|
| |
# Directory holding the training images, and the JSON file whose entries are
# keyed by the image file names inside that directory.
image_folder, metadata_file = "Images/", "descriptions.json"
|
|
| |
def load_metadata(metadata_file):
    """Load and return the decoded contents of a JSON metadata file.

    Args:
        metadata_file: Path to the JSON file. The training code indexes the
            result by image file name, so presumably it is a JSON object
            mapping file names to descriptions — confirm against the data.

    Returns:
        The decoded JSON value.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file does not contain valid JSON.
    """
    # JSON text is UTF-8 by specification; relying on the platform default
    # encoding (e.g. cp1252 on Windows) would mis-decode non-ASCII text.
    with open(metadata_file, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
| |
class ImageDescriptionDataset(Dataset):
    """Dataset yielding ``(image_tensor, description)`` pairs from a folder.

    Args:
        image_folder: Directory containing the image files.
        metadata: Mapping from image file name (relative to ``image_folder``)
            to its description. Items are served in the mapping's key order.
        transform: Optional callable applied to each RGB ``PIL.Image``. When
            omitted, images are resized to 224x224, converted to a tensor and
            normalized with the ImageNet channel mean/std used by torchvision's
            pretrained models.
    """

    def __init__(self, image_folder, metadata, transform=None):
        self.image_folder = image_folder
        self.metadata = metadata
        self.image_names = list(metadata.keys())
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
        self.transform = transform

    def __len__(self):
        """Return the number of images listed in the metadata."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Return the transformed image at position ``idx`` and its description."""
        image_name = self.image_names[idx]
        image_path = os.path.join(self.image_folder, image_name)
        # Image.open() is lazy and keeps the file handle open until closed;
        # convert() loads the pixels into a new image, so the file can be
        # closed as soon as the with-block ends.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        description = self.metadata[image_name]
        return self.transform(image), description
|
|
| |
class LoRAModel(nn.Module):
    """Pretrained ResNet-18 feature extractor followed by a linear head.

    NOTE(review): despite the name, no low-rank (LoRA) adapters are used and
    nothing here freezes the backbone — TODO confirm the intended design.

    Args:
        num_classes: Number of logits produced by the head (default 100).
    """

    def __init__(self, num_classes=100):
        super(LoRAModel, self).__init__()
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        in_features = self.backbone.fc.in_features
        # resnet18 ships with its own 1000-way ImageNet classifier. Replace it
        # with a pass-through so the backbone emits `in_features`-dim pooled
        # features; otherwise self.fc would receive 1000 features and fail
        # with a shape mismatch.
        self.backbone.fc = nn.Identity()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)``.

        ``x`` is presumably a normalized ``(batch, 3, H, W)`` image tensor —
        confirm against the dataset transform.
        """
        features = self.backbone(x)
        return self.fc(features)
|
|
|
|
| |
def _build_label_map(descriptions, max_classes):
    """Assign each distinct description a class index in ``range(max_classes)``.

    Descriptions are de-duplicated and sorted, so the mapping is deterministic
    for a given metadata file.

    Raises:
        ValueError: if there are more distinct descriptions than ``max_classes``.
    """
    classes = sorted(set(descriptions))
    if len(classes) > max_classes:
        raise ValueError(
            f"Found {len(classes)} distinct descriptions, but the model only "
            f"has {max_classes} outputs."
        )
    return {description: index for index, description in enumerate(classes)}


def train_lora(image_folder, metadata, num_epochs=5, batch_size=8, lr=0.001, device=None):
    """Train a LoRAModel classifier on the images listed in ``metadata``.

    Each distinct description is treated as one class, and the model is
    trained with cross-entropy to predict an image's class.

    Args:
        image_folder: Directory containing the images named in ``metadata``.
        metadata: Mapping of image file name -> description. Descriptions are
            assumed to be strings (they must be hashable, sortable and collated
            by DataLoader into a list) — TODO confirm against the JSON file.
        num_epochs: Number of passes over the dataset.
        batch_size: Images per optimization step.
        lr: Adam learning rate.
        device: torch device to train on. Defaults to CUDA when available,
            otherwise CPU.

    Returns:
        A short human-readable summary of the run. start_training_gradio
        returns it for display in the Gradio text output.

    Raises:
        ValueError: if there are more distinct descriptions than model outputs.
    """
    print("Starting training process...")

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = ImageDescriptionDataset(image_folder, metadata)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = LoRAModel().to(device)
    # Targets come from the descriptions themselves: one class per distinct
    # description, bounded by the size of the model's output layer.
    label_of = _build_label_map(metadata.values(), model.fc.out_features)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model.train()
    loss = None
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        for batch_idx, (images, descriptions) in enumerate(dataloader):
            images = images.to(device)
            labels = torch.tensor(
                [label_of[description] for description in descriptions],
                dtype=torch.long,
                device=device,
            )

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0:
                print(f"Batch {batch_idx}, Loss: {loss.item()}")

    print("Training completed.")
    # `loss` stays None when the dataset is empty (no batches were run).
    final_loss = "n/a" if loss is None else f"{loss.item():.4f}"
    return (
        f"Training completed: {num_epochs} epoch(s) over {len(dataset)} image(s) "
        f"in {len(label_of)} class(es). Final batch loss: {final_loss}."
    )
|
|
| |
def start_training_gradio():
    """Gradio callback: load the metadata file, then run training.

    Returns:
        Text for the interface's ``"text"`` output. When train_lora returns a
        string it is passed through. Otherwise (e.g. ``None``) a generic
        completion message is returned, so the output box is never empty.
    """
    print("Preparing dataset...")
    metadata = load_metadata(metadata_file)
    result = train_lora(image_folder, metadata)
    if isinstance(result, str):
        return result
    return "Training completed."
|
|
| |
# Single-button UI: the callback takes no inputs, and its return value is
# shown in a text box. `demo` stays at module level so tools that import this
# file can find the app object.
demo = gr.Interface(
    fn=start_training_gradio,
    inputs=None,
    outputs="text",
    title="Train LoRA on Your Dataset",
    description="Click below to start training with the uploaded images and metadata.",
)

# Start the server only when run as a script. Otherwise importing this module
# would block inside demo.launch().
if __name__ == "__main__":
    demo.launch()
|
|