| | from torch.utils.data import DataLoader
|
| | from level_dataset import LevelDataset
|
| | import random
|
| | from plotter import Plotter
|
| | from datetime import datetime
|
| | import os
|
| | import threading
|
| | import json
|
| | import torch.nn.functional as F
|
| | import torch
|
| |
|
| |
|
| |
|
| |
|
| | def create_dataloaders(json_path, val_json, tokenizer, data_mode, augment, num_tiles,
|
| | negative_prompt_training, block_embeddings, batch_size):
|
| |
|
| | train_dataset = LevelDataset(
|
| | json_path=json_path,
|
| | tokenizer=tokenizer,
|
| | shuffle=True,
|
| | mode=data_mode,
|
| | augment=augment,
|
| | num_tiles=num_tiles,
|
| | negative_captions=negative_prompt_training,
|
| | block_embeddings=block_embeddings
|
| | )
|
| | val_dataset = None
|
| | if val_json is not None:
|
| | val_dataset = LevelDataset(
|
| | json_path=val_json,
|
| | tokenizer=tokenizer,
|
| | shuffle=False,
|
| | mode=data_mode,
|
| | augment=False,
|
| | num_tiles=num_tiles,
|
| | negative_captions=negative_prompt_training,
|
| | block_embeddings=block_embeddings
|
| | )
|
| |
|
| |
|
| | train_dataloader = DataLoader(
|
| | train_dataset,
|
| | batch_size=batch_size,
|
| | shuffle=True,
|
| | num_workers=4,
|
| | drop_last=True,
|
| | persistent_workers=True
|
| | )
|
| |
|
| | val_dataloader = None
|
| | if val_dataset is not None:
|
| | val_dataloader = DataLoader(
|
| | val_dataset,
|
| | batch_size=batch_size,
|
| | shuffle=False,
|
| | num_workers=4,
|
| | drop_last=False,
|
| | persistent_workers=True
|
| | )
|
| |
|
| | return train_dataloader, val_dataloader
|
| |
|
| |
|
| | def get_random_training_samples(train_dataloader, negative_prompt_training, output_dir = None):
|
| | train_dataset = train_dataloader.dataset
|
| |
|
| | sample_indices = [random.randint(0, len(train_dataset) - 1) for _ in range(4)]
|
| |
|
| | sample_captions = [train_dataset[i][1] for i in sample_indices]
|
| | print("Sample captions:")
|
| | for caption in sample_captions:
|
| | print(caption)
|
| |
|
| | sample_negative_captions = ""
|
| | if negative_prompt_training:
|
| | sample_negative_captions = [train_dataset[i][2] for i in sample_indices]
|
| | print("Sample negative captions:")
|
| | for caption in sample_negative_captions:
|
| | print(f" NEG: {caption}")
|
| |
|
| |
|
| | if output_dir is not None:
|
| | os.makedirs(output_dir, exist_ok=True)
|
| | out_path = os.path.join(output_dir, "sample_captions.txt")
|
| | with open(out_path, "w", encoding="utf-8") as f:
|
| | f.write("Sample captions:\n")
|
| | for caption in sample_captions:
|
| | f.write(str(caption) + "\n")
|
| | if negative_prompt_training:
|
| | f.write("\nSample negative captions:\n")
|
| | for caption in sample_negative_captions:
|
| | f.write(str(caption) + "\n")
|
| | print(f"Sample captions written to {out_path}")
|
| |
|
| |
|
| | return sample_captions, sample_negative_captions
|
| |
|
| |
|
| | def start_plotter(log_file, output_dir, left_key, right_key, left_label, right_label, png_name):
|
| | formatted_date = datetime.now().strftime(r'%Y%m%d-%H%M%S')
|
| |
|
| | plotter = Plotter(log_file, update_interval=5.0, left_key=left_key, right_key=right_key,
|
| | left_label=left_label, right_label=right_label, output_png=f'{png_name}_{formatted_date}.png')
|
| | plot_thread = threading.Thread(target=plotter.start_plotting)
|
| | plot_thread.daemon = True
|
| | plot_thread.start()
|
| | print(f"{png_name} plotting enabled. Progress will be saved to {os.path.join(output_dir, f'{png_name}_{formatted_date}.png')}")
|
| | return plotter, plot_thread
|
| |
|
| |
|
| | def kill_plotter(plotter, plot_thread):
|
| | if plot_thread and plot_thread.is_alive():
|
| | plotter.stop_plotting()
|
| | plot_thread.join(timeout=5.0)
|
| | if plot_thread.is_alive():
|
| | print("Warning: Plot thread did not terminate properly")
|
| |
|
| |
|
| | def load_config_from_json(config_path):
|
| | """Load hyperparameters from a JSON config file."""
|
| | try:
|
| | with open(config_path, 'r') as f:
|
| | config = json.load(f)
|
| | print(f"Configuration loaded from {config_path}")
|
| |
|
| |
|
| | print("Loaded hyperparameters:")
|
| | for key, value in config.items():
|
| | print(f" {key}: {value}")
|
| |
|
| | return config
|
| | except (json.JSONDecodeError, FileNotFoundError) as e:
|
| | print(f"Error loading config file: {e}")
|
| | raise e
|
| |
|
| |
|
| | def update_args_from_config(args, config):
|
| | """Update argparse namespace with values from config."""
|
| |
|
| | for key, value in config.items():
|
| | if hasattr(args, key):
|
| | setattr(args, key, value)
|
| | return args
|
| |
|
| |
|
| | def get_scene_from_embeddings(image, block_embeddings):
|
| | """Code copied over from level_dataset, should give limited support for block embeddings"""
|
| |
|
| | batch_size, embedding_dim, height, width = image.shape
|
| |
|
| | flat_samples = image.permute(0, 2, 3, 1).reshape(-1, embedding_dim)
|
| |
|
| |
|
| | flat_samples = F.normalize(flat_samples, p=2, dim=1).cpu()
|
| | block_embeddings = F.normalize(block_embeddings, p=2, dim=1)
|
| |
|
| |
|
| | similarities = torch.matmul(flat_samples, block_embeddings.t())
|
| |
|
| |
|
| | indices = torch.softmax(similarities, dim=1)
|
| |
|
| |
|
| |
|
| | indices = indices.reshape(batch_size, height, width, 13)
|
| | indices = indices.permute(0, 3, 1, 2)
|
| |
|
| | image=indices.detach().cpu()
|
| | return image
|
| |
|
| |
|
| |
|