schrum2 commited on
Commit
b5baa0f
·
verified ·
1 Parent(s): 35bc800

Don't think I need this

Browse files
Files changed (1) hide show
  1. general_training_helper.py +0 -172
general_training_helper.py DELETED
@@ -1,172 +0,0 @@
1
- from torch.utils.data import DataLoader
2
- from level_dataset import LevelDataset
3
- import random
4
- from plotter import Plotter
5
- from datetime import datetime
6
- import os
7
- import threading
8
- import json
9
- import torch.nn.functional as F
10
- import torch
11
-
12
-
13
-
14
-
15
def create_dataloaders(json_path, val_json, tokenizer, data_mode, augment, num_tiles,
                       negative_prompt_training, block_embeddings, batch_size,
                       num_workers=4):
    """Build the training (and optional validation) DataLoaders over LevelDataset.

    Args:
        json_path: Path to the training data JSON file.
        val_json: Path to the validation JSON file, or None to skip validation.
        tokenizer: Tokenizer forwarded to LevelDataset.
        data_mode: Dataset mode string forwarded as LevelDataset's ``mode``.
        augment: Whether to apply data augmentation (training set only).
        num_tiles: Number of distinct tile types in the levels.
        negative_prompt_training: If True, datasets also yield negative captions.
        block_embeddings: Per-tile embedding table forwarded to the dataset.
        batch_size: Batch size used by both loaders.
        num_workers: DataLoader worker processes (default 4, the previous
            hard-coded value).

    Returns:
        (train_dataloader, val_dataloader); val_dataloader is None when
        val_json is None.
    """
    # Initialize dataset
    train_dataset = LevelDataset(
        json_path=json_path,
        tokenizer=tokenizer,
        shuffle=True,
        mode=data_mode,
        augment=augment,
        num_tiles=num_tiles,
        negative_captions=negative_prompt_training,
        block_embeddings=block_embeddings
    )
    val_dataset = None
    if val_json is not None:
        # Validation data is never shuffled or augmented so metrics stay stable.
        val_dataset = LevelDataset(
            json_path=val_json,
            tokenizer=tokenizer,
            shuffle=False,
            mode=data_mode,
            augment=False,
            num_tiles=num_tiles,
            negative_captions=negative_prompt_training,
            block_embeddings=block_embeddings
        )

    # persistent_workers is only legal when num_workers > 0 — guard it so a
    # num_workers=0 debug run doesn't raise inside DataLoader.
    persistent = num_workers > 0

    # Create dataloader
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,  # keep batch shapes constant during training
        persistent_workers=persistent
    )

    val_dataloader = None
    if val_dataset is not None:
        val_dataloader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            drop_last=False,  # evaluate every validation sample
            persistent_workers=persistent
        )

    return train_dataloader, val_dataloader
63
-
64
-
65
def get_random_training_samples(train_dataloader, negative_prompt_training, output_dir = None):
    """Pick four random captions (and optionally negative captions) for logging.

    Args:
        train_dataloader: DataLoader whose ``dataset`` yields tuples where
            index 1 is the caption and (when enabled) index 2 is the
            negative caption.
        negative_prompt_training: If True, also collect negative captions.
        output_dir: If given, the sampled captions are also written to
            ``<output_dir>/sample_captions.txt``.

    Returns:
        (sample_captions, sample_negative_captions). The second element is an
        empty list when negative_prompt_training is False.
    """
    train_dataset = train_dataloader.dataset
    # Sample four random captions from the dataset.
    # randint samples WITH replacement, so duplicates are possible.
    sample_indices = [random.randint(0, len(train_dataset) - 1) for _ in range(4)]

    sample_captions = [train_dataset[i][1] for i in sample_indices]
    print("Sample captions:")
    for caption in sample_captions:
        print(caption)

    # BUGFIX: this was initialized to "" (a str) but used and returned as a
    # list everywhere else; use an empty list so the return type is consistent
    # (both are falsy, so truthiness checks by callers are unaffected).
    sample_negative_captions = []
    if negative_prompt_training:
        sample_negative_captions = [train_dataset[i][2] for i in sample_indices]
        print("Sample negative captions:")
        for caption in sample_negative_captions:
            print(f" NEG: {caption}")

    # Write captions to a file
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        out_path = os.path.join(output_dir, "sample_captions.txt")
        with open(out_path, "w", encoding="utf-8") as f:
            f.write("Sample captions:\n")
            for caption in sample_captions:
                f.write(str(caption) + "\n")
            if negative_prompt_training:
                f.write("\nSample negative captions:\n")
                for caption in sample_negative_captions:
                    f.write(str(caption) + "\n")
        print(f"Sample captions written to {out_path}")

    return sample_captions, sample_negative_captions
98
-
99
-
100
def start_plotter(log_file, output_dir, left_key, right_key, left_label, right_label, png_name):
    """Launch a daemon thread that live-plots training progress from a log file.

    Args:
        log_file: Path of the metrics log the Plotter polls.
        output_dir: Directory where the output PNG should be saved.
        left_key/right_key: Log keys plotted on the left/right axes.
        left_label/right_label: Axis labels for those keys.
        png_name: Base name used for the timestamped output PNG.

    Returns:
        (plotter, plot_thread) so the caller can later shut it down via
        kill_plotter().
    """
    formatted_date = datetime.now().strftime(r'%Y%m%d-%H%M%S')

    # BUGFIX: the progress message always claimed the PNG lands in output_dir,
    # but Plotter was given a bare filename (so it was written to the CWD).
    # Join the path once and use it for both, so the file actually goes where
    # the message says.
    output_png = os.path.join(output_dir, f'{png_name}_{formatted_date}.png')

    plotter = Plotter(log_file, update_interval=5.0, left_key=left_key, right_key=right_key,
                      left_label=left_label, right_label=right_label, output_png=output_png)
    plot_thread = threading.Thread(target=plotter.start_plotting)
    # Daemon thread: never blocks interpreter shutdown.
    plot_thread.daemon = True
    plot_thread.start()
    print(f"{png_name} plotting enabled. Progress will be saved to {output_png}")
    return plotter, plot_thread
110
-
111
-
112
def kill_plotter(plotter, plot_thread):
    """Stop a live plotter started by start_plotter and wait for its thread.

    No-op when plot_thread is None/falsy or has already finished. Waits at
    most 5 seconds for the thread to exit and warns if it is still alive.
    """
    if not plot_thread or not plot_thread.is_alive():
        return
    plotter.stop_plotting()
    plot_thread.join(timeout=5.0)
    if plot_thread.is_alive():
        print("Warning: Plot thread did not terminate properly")
118
-
119
-
120
def load_config_from_json(config_path):
    """Load hyperparameters from a JSON config file.

    Args:
        config_path: Path to the JSON config file.

    Returns:
        dict of hyperparameters parsed from the file.

    Raises:
        FileNotFoundError: If the file does not exist.
        json.JSONDecodeError: If the file contents are not valid JSON.
    """
    try:
        # Explicit encoding so configs parse identically on every platform.
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        print(f"Configuration loaded from {config_path}")

        # Print the loaded config for verification
        print("Loaded hyperparameters:")
        for key, value in config.items():
            print(f" {key}: {value}")

        return config
    except (json.JSONDecodeError, FileNotFoundError) as e:
        print(f"Error loading config file: {e}")
        # Bare raise re-raises with the original traceback intact
        # (``raise e`` would restart the traceback from here).
        raise
136
-
137
-
138
def update_args_from_config(args, config):
    """Overwrite attributes of an argparse namespace with config values.

    Only keys that already exist as attributes on ``args`` are applied;
    unknown config keys are silently ignored. Returns the same (mutated)
    namespace for convenience.
    """
    recognized = ((key, value) for key, value in config.items() if hasattr(args, key))
    for key, value in recognized:
        setattr(args, key, value)
    return args
145
-
146
-
147
def get_scene_from_embeddings(image, block_embeddings):
    """Convert an embedding-space image batch to per-tile softmax scores.

    Code copied over from level_dataset, should give limited support for
    block embeddings.

    Args:
        image: Tensor of shape (batch, embedding_dim, height, width).
        block_embeddings: Tensor of shape (num_tiles, embedding_dim) — one
            embedding vector per tile type.

    Returns:
        Detached CPU tensor of shape (batch, num_tiles, height, width):
        a softmax over cosine similarities between each spatial position
        and every tile embedding (probabilities, not argmax indices).
    """
    batch_size, embedding_dim, height, width = image.shape
    # GENERALIZED: the tile count was hard-coded to 13; derive it from the
    # embedding table so any number of tiles works (13 still behaves as before).
    num_tiles = block_embeddings.shape[0]

    # Reshape sample to [batch_size * height * width, embedding_dim]
    flat_samples = image.permute(0, 2, 3, 1).reshape(-1, embedding_dim)

    # Normalize vectors for cosine similarity. BUGFIX: also move the
    # embeddings to CPU — flat_samples is forced to CPU, and a cross-device
    # matmul would raise if block_embeddings lived on the GPU.
    flat_samples = F.normalize(flat_samples, p=2, dim=1).cpu()
    block_embeddings = F.normalize(block_embeddings, p=2, dim=1).cpu()

    # Calculate cosine similarity between each position and all tile embeddings
    similarities = torch.matmul(flat_samples, block_embeddings.t())

    # Softmax over tiles (renamed from the misleading 'indices' — these are
    # probabilities, not the indices of the most similar tiles).
    tile_probs = torch.softmax(similarities, dim=1)

    # Reshape back to [batch_size, num_tiles, height, width]
    tile_probs = tile_probs.reshape(batch_size, height, width, num_tiles)
    tile_probs = tile_probs.permute(0, 3, 1, 2)

    return tile_probs.detach().cpu()
171
-
172
-