Leacb4 commited on
Commit
4912235
·
verified ·
1 Parent(s): c97f09a

Upload optuna/optuna_optimisation.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. optuna/optuna_optimisation.py +295 -0
optuna/optuna_optimisation.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
Optuna hyperparameter optimization for the main CLIP model.
This script uses Optuna to find the best hyperparameters to reduce overfitting.
"""

import os
import sys

# Add parent directory to path to import modules
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Must be set before tokenizers are loaded; silences the HF tokenizers
# fork/parallelism warning triggered by DataLoader worker processes.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
from transformers import CLIPModel as CLIPModel_transformers
import optuna
from optuna.trial import TrialState
import warnings
import config
from main_model import (
    CustomDataset,
    load_models,
    train_one_epoch_enhanced,
    valid_one_epoch
)
from transformers import CLIPProcessor

warnings.filterwarnings("ignore")

# Global variables for data (to avoid reloading for each trial)
# Populated once in main(); read by objective() on every trial.
TRAIN_LOADER = None
VAL_LOADER = None
FEATURE_MODELS = None
DEVICE = None
40
def prepare_data(subset_size=5000, batch_size=32):
    """
    Build train/validation DataLoaders over a fixed random subset of the dataset.

    A seeded subset and split are used so that every Optuna trial sees the
    same data, keeping trial objective values comparable.

    Args:
        subset_size: Maximum number of samples to draw (capped at dataset size).
        batch_size: Batch size for both loaders.

    Returns:
        Tuple of (train_loader, val_loader) with an 80/20 split of the subset.
    """
    print(f"\n๐Ÿ“‚ Loading data...")
    df = pd.read_csv(config.local_dataset_path)
    # Drop rows whose local image path is missing -- they cannot be loaded.
    df_clean = df.dropna(subset=[config.column_local_image_path])
    print(f" Total samples: {len(df_clean)}")

    # Create dataset
    dataset = CustomDataset(df_clean)

    # Cap the subset at the available data, then split 80/20.
    subset_size = min(subset_size, len(dataset))
    train_size = int(0.8 * subset_size)
    val_size = subset_size - train_size

    # Fixed seeds make the subset choice and the split reproducible.
    np.random.seed(42)
    subset_indices = np.random.choice(len(dataset), subset_size, replace=False)
    subset_dataset = torch.utils.data.Subset(dataset, subset_indices)

    train_dataset, val_dataset = random_split(
        subset_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )

    # pin_memory only helps when batches are transferred to a CUDA device;
    # compute the flag once instead of repeating the conditional per loader.
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=pin_memory
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=pin_memory
    )

    print(f" Train: {len(train_dataset)} samples")
    print(f" Val: {len(val_dataset)} samples")

    return train_loader, val_loader
88
+
89
def objective(trial):
    """
    Objective function for Optuna optimization.

    Trains a fresh CLIP model with the trial's suggested hyperparameters
    on the global data loaders and returns the best validation loss seen
    (to be minimized). Raises optuna.TrialPruned when the pruner fires.
    """
    global TRAIN_LOADER, VAL_LOADER, FEATURE_MODELS, DEVICE

    # Suggest hyperparameters (log scale for LR / weight decay).
    learning_rate = trial.suggest_float("learning_rate", 1e-6, 5e-5, log=True)
    temperature = trial.suggest_float("temperature", 0.05, 0.15)
    alignment_weight = trial.suggest_float("alignment_weight", 0.1, 0.6)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 5e-4, log=True)

    print(f"\n{'='*80}")
    print(f"Trial {trial.number}")
    print(f" LR: {learning_rate:.2e}, Temp: {temperature:.4f}")
    print(f" Align weight: {alignment_weight:.3f}, Weight decay: {weight_decay:.2e}")
    print(f"{'='*80}")

    # Create fresh model for this trial
    clip_model = CLIPModel_transformers.from_pretrained(
        'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
    ).to(DEVICE)

    # Optimizer with weight decay for regularization
    optimizer = torch.optim.AdamW(
        clip_model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay
    )

    # Create processor
    processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')

    # Train for a few epochs (reduced for faster optimization)
    num_epochs = 5
    best_val_loss = float('inf')
    patience_counter = 0
    patience = 2

    # Loop-invariant lookups hoisted out of the epoch loop.
    color_model = FEATURE_MODELS[config.color_column]
    hierarchy_model = FEATURE_MODELS[config.hierarchy_column]

    # try/finally guarantees the model/optimizer/processor are released and
    # the CUDA cache is emptied even when the trial is pruned -- previously
    # raise optuna.TrialPruned() skipped the cleanup and leaked GPU memory
    # across pruned trials.
    try:
        for epoch in range(num_epochs):
            # Training (the per-epoch metrics dict is not used here).
            train_loss, _metrics = train_one_epoch_enhanced(
                clip_model, TRAIN_LOADER, optimizer, FEATURE_MODELS,
                color_model, hierarchy_model, DEVICE, processor,
                temperature=temperature, alignment_weight=alignment_weight
            )

            # Validation
            val_loss = valid_one_epoch(
                clip_model, VAL_LOADER, FEATURE_MODELS, DEVICE, processor,
                temperature=temperature, alignment_weight=alignment_weight
            )

            print(f" Epoch {epoch+1}/{num_epochs} - Train: {train_loss:.4f}, Val: {val_loss:.4f}")

            # Track best validation loss
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1

            # Early stopping within trial
            if patience_counter >= patience:
                print(f" Early stopping at epoch {epoch+1}")
                break

            # Report intermediate value for pruning
            trial.report(val_loss, epoch)

            # Handle pruning based on intermediate value
            if trial.should_prune():
                print(f" Trial pruned at epoch {epoch+1}")
                raise optuna.TrialPruned()
    finally:
        # Clean up memory (runs on normal exit, early stop, and pruning).
        del clip_model, optimizer, processor
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return best_val_loss
174
+
175
def main():
    """
    Run the Optuna study end-to-end.

    Loads feature models and data once (shared by all trials via module
    globals), optimizes objective() for 30 trials, then reports the best
    hyperparameters, persists results/study to the parent directory, and
    optionally renders plots.
    """
    global TRAIN_LOADER, VAL_LOADER, FEATURE_MODELS, DEVICE

    print("="*80)
    print("๐Ÿ” Optuna Hyperparameter Optimization")
    print("="*80)

    # Set device
    DEVICE = config.device
    print(f"\nDevice: {DEVICE}")

    # Load feature models once
    print("\n๐Ÿ”ง Loading feature models...")
    FEATURE_MODELS = load_models()

    # Prepare data once (use smaller subset for faster optimization)
    TRAIN_LOADER, VAL_LOADER = prepare_data(subset_size=5000, batch_size=32)

    # Create Optuna study
    print("\n๐ŸŽฏ Creating Optuna study...")
    study = optuna.create_study(
        direction="minimize",
        pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=2),
        study_name="clip_hyperparameter_optimization"
    )

    # Run optimization
    print("\n๐Ÿš€ Starting optimization...")
    print(f" Running 30 trials (this may take a while)...\n")

    study.optimize(
        objective,
        n_trials=30,
        timeout=None,
        # NOTE: catch=(Exception,) keeps one failing trial from killing the
        # whole study, but it also means every trial can silently fail.
        catch=(Exception,),
        show_progress_bar=True
    )

    # Print results
    print("\n" + "="*80)
    print("โœ… Optimization Complete!")
    print("="*80)

    # Trial-state buckets (computed early so we can bail out safely below).
    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

    # study.best_trial raises ValueError when no trial completed (possible
    # since catch=(Exception,) absorbs trial failures) -- guard it.
    if not complete_trials:
        print("\nโš ๏ธ No trial completed successfully; nothing to report.")
        return

    def _fmt_param(key, value):
        # Scientific notation for log-scale params, fixed-point otherwise.
        if 'learning_rate' in key or 'weight_decay' in key:
            return f"{key}: {value:.2e}"
        return f"{key}: {value:.4f}"

    print(f"\n๐Ÿ“Š Best trial:")
    trial = study.best_trial
    print(f" Value (Val Loss): {trial.value:.4f}")
    print(f"\n Best hyperparameters:")
    for key, value in trial.params.items():
        print(f" {_fmt_param(key, value)}")

    # All output artifacts go to the parent directory; compute it once.
    parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    # Save results in parent directory
    results_file = os.path.join(parent_dir, "optuna_results.txt")
    with open(results_file, 'w') as f:
        f.write("="*80 + "\n")
        f.write("Optuna Hyperparameter Optimization Results\n")
        f.write("="*80 + "\n\n")
        f.write(f"Best trial value (validation loss): {trial.value:.4f}\n\n")
        f.write("Best hyperparameters:\n")
        for key, value in trial.params.items():
            f.write(f" {_fmt_param(key, value)}\n")
        f.write("\n" + "="*80 + "\n")
        f.write("All trials:\n")
        f.write("="*80 + "\n\n")

        df_results = study.trials_dataframe()
        f.write(df_results.to_string())

    print(f"\n๐Ÿ’พ Results saved to: {results_file}")

    # Save study for later analysis
    import pickle
    study_file = os.path.join(parent_dir, 'optuna_study.pkl')
    with open(study_file, 'wb') as f:
        pickle.dump(study, f)
    print(f"๐Ÿ’พ Study object saved to: {study_file}")

    # Print pruned trials statistics
    print(f"\n๐Ÿ“ˆ Statistics:")
    print(f" Number of finished trials: {len(study.trials)}")
    print(f" Number of pruned trials: {len(pruned_trials)}")
    print(f" Number of complete trials: {len(complete_trials)}")

    # Visualization (optional, requires optuna-dashboard or matplotlib)
    try:
        from optuna.visualization import plot_optimization_history, plot_param_importances

        # Plot optimization history
        fig1 = plot_optimization_history(study)
        history_file = os.path.join(parent_dir, "optuna_optimization_history.png")
        fig1.write_image(history_file)
        print(f"๐Ÿ“Š Optimization history saved to: {history_file}")

        # Plot parameter importances
        fig2 = plot_param_importances(study)
        importance_file = os.path.join(parent_dir, "optuna_param_importances.png")
        fig2.write_image(importance_file)
        print(f"๐Ÿ“Š Parameter importances saved to: {importance_file}")
    except Exception as e:
        print(f"\nโš ๏ธ Visualization skipped: {e}")
        print(" Install plotly and kaleido for visualizations: pip install plotly kaleido")

    print("\n" + "="*80)
    print("๐ŸŽ‰ Done! Update your config with the best hyperparameters.")
    print("="*80)
293
+
294
# Script entry point: run the full optimization when executed directly.
if __name__ == "__main__":
    main()