| import os |
| import torch |
| import torch.nn as nn |
| import torchvision.transforms as transforms |
| from torch.utils.data import DataLoader, Dataset |
| from PIL import Image |
| from transformers import T5ForConditionalGeneration, T5Tokenizer |
| import matplotlib.pyplot as plt |
# Global compute device for the whole script; everything below is pinned to
# the CPU. Change to "cuda" (when available) to run on a GPU.
device ="cpu"
class TextEncoder(nn.Module):
    """Produce a fixed-size text embedding from a pretrained T5 model.

    The full conditional-generation model is loaded, but only its encoder
    stack is invoked when embedding text.
    """

    def __init__(self, encoder_model_name):
        super(TextEncoder, self).__init__()
        # Both artifacts are fetched from (or read out of) the HF cache.
        self.tokenizer = T5Tokenizer.from_pretrained(encoder_model_name)
        self.encoder = T5ForConditionalGeneration.from_pretrained(encoder_model_name)
        self.encoder.to(device)

    def encode_text(self, text):
        """Tokenize ``text`` (str or list of str) and return the encoder's
        hidden state at token position 0, shape (batch, d_model)."""
        tokenized = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        tokenized = {name: tensor.to(device) for name, tensor in tokenized.items()}
        encoder_output = self.encoder.encoder(**tokenized)
        # The first token's hidden state serves as the sentence embedding.
        return encoder_output.last_hidden_state[:, 0, :]
|
|
class ConditionalDiffusionModel(nn.Module):
    """Map a 512-dim text embedding to a 64-dim image embedding via a
    small fully-connected stack."""

    def __init__(self):
        super(ConditionalDiffusionModel, self).__init__()
        layers = [
            nn.Linear(512, 768),
            nn.ReLU(),
            nn.Linear(768, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, text_embeddings):
        """Return (batch, 64) image embeddings for (batch, 512) input."""
        return self.model(text_embeddings)
|
|
class SuperResolutionDiffusionModel(nn.Module):
    """Refine an RGB image with a 3-layer convolutional stack.

    All convolutions use kernel 3 / padding 1, so the spatial size of the
    input is preserved (despite the name, no upsampling happens here).
    """

    def __init__(self):
        super(SuperResolutionDiffusionModel, self).__init__()
        refine = [
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
        ]
        self.model = nn.Sequential(*refine)

    def forward(self, input_image):
        """Return a refined image with the same (batch, 3, H, W) shape."""
        return self.model(input_image)
|
|
class TextToImageModel(nn.Module):
    """End-to-end text-to-image pipeline.

    text -> TextEncoder -> (batch, 512) embedding
         -> ConditionalDiffusionModel -> (batch, 64) image embedding
         -> tiled into a seed image -> 6x refinement passes.
    """

    def __init__(self, text_encoder, conditional_diffusion_model, super_resolution_diffusion_model):
        super(TextToImageModel, self).__init__()
        self.text_encoder = text_encoder
        self.conditional_diffusion_model = conditional_diffusion_model
        self.super_resolution_diffusion_model = super_resolution_diffusion_model

    def forward(self, text):
        """Generate a (batch, 3, 64, 64) image tensor conditioned on ``text``.

        Fixes vs. the previous version:
        - the conditional-diffusion output is actually used (it used to be
          computed and then discarded, with generation starting from
          unconditioned random noise — no gradient reached the text branch);
        - the batch size follows the input instead of being hard-coded to 1;
        - the seed image inherits the embeddings' device instead of always
          being allocated on the default device.
        """
        text_embeddings = self.text_encoder.encode_text(text)
        image_embeddings = self.conditional_diffusion_model(text_embeddings)
        # (batch, 64) -> (batch, 1, 8, 8), then tile to (batch, 3, 64, 64) so
        # the refinement stack receives a text-conditioned seed image.
        image = image_embeddings.view(-1, 1, 8, 8).repeat(1, 3, 8, 8)
        for _ in range(6):
            image = self.super_resolution_diffusion_model(image)
        return image
|
|
class CustomDataset(Dataset):
    """Dataset of (caption, image) pairs.

    ``annotations_file`` format: one sample per line, "<filename> <caption>"
    (filename and caption separated by the first space). Blank lines and
    lines without a caption are skipped at load time — previously a trailing
    blank line or a caption-less line produced a 1-element list that crashed
    ``__getitem__`` with an unpacking error.
    """

    def __init__(self, annotations_file, img_dir, transform=None):
        with open(annotations_file, 'r') as f:
            lines = f.readlines()
        # Keep only well-formed lines: both a filename and a caption present.
        self.img_labels = [
            parts
            for parts in (line.strip().split(' ', 1) for line in lines)
            if len(parts) == 2
        ]
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_name, text = self.img_labels[idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        # (caption, image) — matches the training loop's unpacking order.
        return text, image
|
|
def save_checkpoint(model, optimizer, epoch, checkpoint_path):
    """Serialize model and optimizer state plus the epoch counter to disk."""
    torch.save(
        {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
        },
        checkpoint_path,
    )
|
|
def load_checkpoint(model, optimizer, checkpoint_path):
    """Restore model/optimizer state from ``checkpoint_path`` in place.

    Returns the epoch stored in the checkpoint, or 0 when the file does not
    exist (fresh start). ``map_location="cpu"`` is passed so checkpoints
    written on a GPU machine still load here, where the script runs on the
    CPU (previously such a file raised a deserialization error).
    """
    if not os.path.isfile(checkpoint_path):
        return 0
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']
|
|
| def test_inference(model, text): |
| model.eval() |
| with torch.no_grad(): |
| generated_image = model(text) |
| return generated_image |
|
|
def visualize_image(image_tensor):
    """Display a (1, 3, H, W) image tensor with matplotlib.

    The tensor is min-max normalized into [0, 1] for display. A constant
    image previously caused a 0/0 division (an all-NaN canvas); it now maps
    to an all-zero (black) image instead.
    """
    image = image_tensor.squeeze(0).cpu().detach()
    lo, hi = image.min(), image.max()
    if hi > lo:
        image = (image - lo) / (hi - lo)
    else:
        # Degenerate case: every pixel has the same value.
        image = torch.zeros_like(image)
    plt.imshow(image.permute(1, 2, 0))  # CHW -> HWC for imshow
    plt.show()
|
|
if __name__ == "__main__":
    # ---- Hyperparameters and paths ----
    batch_size = 4
    learning_rate = 1e-4
    num_epochs = 1000
    checkpoint_path = 'checkpoint.pth'
    annotations_file = 'annotations.txt'  # one "<filename> <caption>" per line
    img_dir = 'images/'

    # ---- Model assembly: text encoder -> conditional diffusion -> refinement ----
    text_encoder = TextEncoder("google-t5/t5-small")
    conditional_diffusion_model = ConditionalDiffusionModel()
    super_resolution_diffusion_model = SuperResolutionDiffusionModel()
    text_to_image_model = TextToImageModel(text_encoder, conditional_diffusion_model, super_resolution_diffusion_model)

    # One optimizer over every sub-module's parameters; plain pixel-space MSE.
    optimizer = torch.optim.Adam(text_to_image_model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    # Resume from a previous run when a checkpoint file exists (0 otherwise).
    start_epoch = load_checkpoint(text_to_image_model, optimizer, checkpoint_path)

    # Target images are resized to the model's fixed 64x64 working resolution.
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])

    dataset = CustomDataset(annotations_file, img_dir, transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # ---- Training loop ----
    # NOTE(review): verify that the generator's output batch size matches
    # image_batch's — MSELoss broadcasting can silently hide a mismatch
    # between the generated and target tensors.
    text_to_image_model.train()
    for epoch in range(start_epoch, num_epochs):
        for i, (text_batch, image_batch) in enumerate(dataloader):
            optimizer.zero_grad()
            images = text_to_image_model(text_batch)
            target_images = image_batch.to(device)
            loss = criterion(images, target_images)
            loss.backward()
            optimizer.step()

        # Only the last batch's loss is reported for the epoch; a checkpoint
        # is written after every epoch so training can resume at epoch+1.
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')
        save_checkpoint(text_to_image_model, optimizer, epoch+1, checkpoint_path)

    print("Training completed.")

    # Quick qualitative check: generate and display one sample.
    sample_text = "A big ape."
    generated_image = test_inference(text_to_image_model, sample_text)
    visualize_image(generated_image)
|
|