| """Utility functions""" |
| import importlib |
| import random |
| import re |
| import torch |
| import numpy as np |
| from PIL import Image |
|
|
| |
def normalize(image, rescale=True):
    """Apply the affine map ``2 * x - 1`` to an image, optionally rescaling first.

    Args:
        image: The image data. When ``rescale`` is true it must provide a
            ``.float()`` method (e.g. a ``torch.Tensor``).
        rescale: If true, the image is first converted with ``.float()`` and
            divided by 255.0 before the affine map is applied.

    Returns:
        normalize_image: ``2 * x - 1`` where ``x`` is the (optionally
            rescaled) image. For inputs in [0, 255] with ``rescale=True``
            (or in [0, 1] with ``rescale=False``) this lands in [-1, 1].
    """
    # Only touch dtype/scale when asked to; otherwise use the input as-is.
    scaled = image.float() / 255.0 if rescale else image
    return 2 * scaled - 1
|
|
|
|
|
|
def process_caption(caption):
    """Process a caption to ensure proper formatting and remove duplicates.

    The caption is first trimmed so that it ends on a sentence terminator
    (``.``, ``!`` or ``?``): any trailing text after the last terminator is
    dropped. If the caption contains no terminator at all it is left
    unchanged. The result is then split into sentences and repeated
    sentences are removed, keeping the first occurrence of each.

    Args:
        caption: A string containing the caption text

    Returns:
        processed_caption: A string with processed caption
    """
    terminators = ('.', '!', '?')

    # Drop a dangling, unterminated fragment at the end of the caption. The
    # same terminator set is used here and in the split below so that a
    # caption ending in a complete "!" or "?" sentence is not truncated.
    if not caption.endswith(terminators):
        last_terminator_index = max(caption.rfind(mark) for mark in terminators)
        if last_terminator_index != -1:
            caption = caption[:last_terminator_index + 1]

    sentences = re.split(r'(?<=[.!?])\s+', caption)

    # dict.fromkeys keeps insertion order, giving an order-preserving,
    # linear-time de-duplication; empty pieces from the split are skipped.
    unique_sentences = dict.fromkeys(
        sentence for sentence in sentences if sentence
    )

    return ' '.join(unique_sentences)
|
|
|
|
def initiate_time_steps(step, total_timestep, batch_size, config):
    """A helper function to initiate time steps for the diffusion model.

    Three strategies are selected from ``config``, checked in this order:

    * ``config.rand_timestep_equal_int``: evenly spaced time steps with a
      random offset. The spacing is ``total_timestep // batch_size`` and the
      offset is drawn uniformly from ``[0, spacing - 1]``.
    * ``config.random_timestep_per_iteration``: independent uniform random
      time steps in ``[0, total_timestep)``.
    * otherwise: the constant ``step`` repeated ``batch_size`` times.

    Args:
        step: An integer of the constant step
        total_timestep: An integer of the total timesteps of the diffusion model
        batch_size: An integer of the batch size
        config: A config object exposing ``rand_timestep_equal_int`` and
            ``random_timestep_per_iteration``

    Returns:
        timesteps: A tensor of shape [batch_size,] of the time steps

    Raises:
        ValueError: In the equal-interval mode, if ``batch_size`` is not in
            ``[1, total_timestep]`` (the spacing would be zero or undefined).
    """
    if config.rand_timestep_equal_int:
        if not 0 < batch_size <= total_timestep:
            raise ValueError(
                "equal-interval time steps require "
                f"0 < batch_size <= total_timestep, got batch_size={batch_size}, "
                f"total_timestep={total_timestep}"
            )
        interval_val = total_timestep // batch_size
        start_point = random.randint(0, interval_val - 1)
        # Generate exactly batch_size steps. Iterating a range up to
        # total_timestep would yield extra entries whenever total_timestep is
        # not a multiple of batch_size. The largest value produced here is
        # batch_size * interval_val - 1, which is always < total_timestep.
        return torch.tensor(
            [start_point + index * interval_val for index in range(batch_size)]
        ).long()

    if config.random_timestep_per_iteration:
        return torch.randint(0, total_timestep, (batch_size,)).long()

    # Default: every sample in the batch uses the same, caller-supplied step.
    return torch.tensor([step] * batch_size).long()