| | import torch |
| | from torch import nn, optim |
| | from torch.utils.data import DataLoader, Dataset |
| | from torchvision import transforms, datasets, models |
| | from PIL import Image |
| | import json |
| | import os |
| | import gradio as gr |
| | import shutil |
| |
|
| | |
# Directory holding the training images; each metadata key is joined onto this
# path to locate the image file.
image_folder = "Images/"
# JSON file that start_training_gradio() reads through load_metadata(); its keys
# are used as image file names and its values as the matching descriptions.
metadata_file = "descriptions.json"
| |
|
| | |
def load_metadata(metadata_file):
    """Load the image-description metadata from a JSON file.

    Args:
        metadata_file: Path to the JSON file to read.

    Returns:
        The decoded JSON document, exactly as produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # Decode explicitly as UTF-8 so descriptions with non-ASCII characters
    # load the same way regardless of the platform's locale encoding.
    with open(metadata_file, "r", encoding="utf-8") as f:
        return json.load(f)
| |
|
| | |
class ImageDescriptionDataset(Dataset):
    """Dataset yielding ``(image_tensor, description)`` pairs.

    Args:
        image_folder: Directory that contains the image files.
        metadata: Mapping whose keys are image file names (relative to
            ``image_folder``) and whose values are the matching descriptions.
    """

    def __init__(self, image_folder, metadata):
        self.image_folder = image_folder
        self.metadata = metadata
        # Freeze the key order once so that an index always maps to the same image.
        self.image_names = list(metadata)
        # Resize to a fixed 512x512, convert to a tensor, then normalise each
        # channel with mean 0.5 / std 0.5.
        self.transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def __len__(self):
        """Return the number of images listed in the metadata."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Return the transformed image and its description for ``idx``.

        The description is returned unchanged, exactly as stored in the metadata.
        """
        image_name = self.image_names[idx]
        image_path = os.path.join(self.image_folder, image_name)
        # Image.open keeps the underlying file open; the context manager
        # releases the handle as soon as the RGB copy has been made, instead
        # of leaking one descriptor per sample until garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        return self.transform(image), self.metadata[image_name]
| |
|
| | |
class LoRALayer(nn.Module):
    """Wrap a linear layer with a trainable low-rank (LoRA) residual branch.

    The forward pass computes
    ``original_layer(x) + scale * lora_down(lora_up(x))``.

    Args:
        original_layer: Layer to adapt. It must expose ``in_features`` and
            ``out_features``. Its parameters are frozen in place by this
            constructor so that only the adapter is trained.
        rank: Width of the low-rank bottleneck; must be at least 1.
        alpha: Optional LoRA scaling numerator. The residual is multiplied by
            ``alpha / rank``; ``None`` (the default) means a scale of 1.0.

    Raises:
        ValueError: If ``rank`` is smaller than 1.
    """

    def __init__(self, original_layer, rank=4, alpha=None):
        super().__init__()
        if rank < 1:
            raise ValueError(f"rank must be a positive integer, got {rank}")

        self.original_layer = original_layer
        self.rank = rank
        self.scale = 1.0 if alpha is None else alpha / rank

        # LoRA trains only the adapter; the wrapped weights stay fixed.
        for param in self.original_layer.parameters():
            param.requires_grad_(False)

        # NOTE: attribute names are kept as-is because they are state_dict
        # keys. `lora_up` projects INTO the rank-sized space and `lora_down`
        # projects back out to the layer's output size.
        self.lora_up = nn.Linear(original_layer.in_features, rank, bias=False)
        self.lora_down = nn.Linear(rank, original_layer.out_features, bias=False)

        # Zero the output projection so the residual starts at exactly zero:
        # a freshly wrapped layer reproduces the original layer's output.
        nn.init.zeros_(self.lora_down.weight)

    def forward(self, x):
        """Return the wrapped layer's output plus the scaled low-rank update."""
        residual = self.lora_down(self.lora_up(x))
        return self.original_layer(x) + self.scale * residual
| |
|
| | |
class LoRAModel(nn.Module):
    """ImageNet-pretrained ResNet-18 whose final ``fc`` layer carries a LoRA adapter.

    Every pretrained backbone parameter is frozen; only the parameters created
    by the ``LoRALayer`` wrapped around ``fc`` remain trainable.
    """

    def __init__(self):
        super().__init__()

        # `pretrained=True` is deprecated in recent torchvision releases in
        # favour of the `weights=` enum. Use the enum when it exists and fall
        # back to the legacy flag on older installations. IMAGENET1K_V1 is the
        # checkpoint that `pretrained=True` selects.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            self.backbone = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
        else:
            self.backbone = models.resnet18(pretrained=True)

        # Freeze the pretrained weights BEFORE attaching the adapter, so the
        # adapter's freshly created parameters are the only trainable ones.
        # NOTE: BatchNorm running statistics are buffers, not parameters, and
        # still update while the module is in training mode.
        for param in self.backbone.parameters():
            param.requires_grad_(False)

        self.backbone.fc = LoRALayer(self.backbone.fc)

    def forward(self, x):
        """Run ``x`` through the adapted backbone and return its output."""
        return self.backbone(x)
| |
|
| | |
def train_lora(image_folder, metadata, num_epochs=5, batch_size=8,
               learning_rate=0.001, save_path='/mnt/data/lora_model.pth'):
    """Train a ``LoRAModel`` on the images in ``metadata`` and save its weights.

    Args:
        image_folder: Directory containing the image files.
        metadata: Mapping of image file name to description.
        num_epochs: Number of passes over the dataset.
        batch_size: Number of images per optimisation step.
        learning_rate: Learning rate for the Adam optimizer.
        save_path: Destination file for the model's ``state_dict``. Missing
            parent directories are created.

    Returns:
        A Markdown message containing a download link to ``save_path``.

    Raises:
        ValueError: If ``metadata`` is empty.
    """
    print("Starting LoRA training process...")

    # Fail early with a clear message instead of building the model and data
    # loader for a dataset that has nothing in it.
    if not metadata:
        raise ValueError("metadata is empty: there are no images to train on")

    dataset = ImageDescriptionDataset(image_folder, metadata)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = LoRAModel()
    criterion = nn.CrossEntropyLoss()
    # Hand the optimizer only the parameters that can actually receive
    # gradients; frozen weights would just be dead entries.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=learning_rate)

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        for batch_idx, (images, _descriptions) in enumerate(dataloader):
            # NOTE(review): placeholder targets. The labels are drawn at
            # random for every batch and the descriptions are ignored, so the
            # loss carries no real training signal. Replace with targets
            # derived from the descriptions before relying on the result.
            labels = torch.randint(0, 100, (images.size(0),))

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0:
                print(f"Batch {batch_idx}, Loss: {loss.item()}")

    # torch.save does not create missing directories, so make sure the target
    # directory exists before writing.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), save_path)
    print(f"Model saved at {save_path}")

    print(f"Training completed. The model is saved and ready for download at {save_path}.")

    # Build the link from save_path so it stays correct for non-default paths.
    return f"Training completed. Download the model from: [Download Model](sandbox:{save_path})"
| |
|
| | |
def start_training_gradio():
    """Gradio callback: load the configured metadata file and run training.

    Reads the module-level ``metadata_file`` and ``image_folder`` settings and
    returns whatever message ``train_lora`` produces.
    """
    print("Loading metadata and preparing dataset...")
    return train_lora(image_folder, load_metadata(metadata_file))
| |
|
# Web UI with no input widgets and a single text output that displays the
# message returned by start_training_gradio.
demo = gr.Interface(
    fn=start_training_gradio,
    inputs=None,
    outputs="text",
    title="Train LoRA Model",
    description="Fine-tune a model using LoRA for consistent image generation.",
)

if __name__ == "__main__":
    # Start the server only when executed as a script, so importing this
    # module (for tests or reuse of the classes above) does not block on a
    # running web server.
    demo.launch()
| |
|