miachiquier commited on
Commit
dc8da04
·
verified ·
1 Parent(s): b4170c2

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -1447
app.py DELETED
@@ -1,1447 +0,0 @@
1
- import PIL
2
- import math
3
- import torch
4
- import random
5
- import os
6
- import numpy as np
7
- import pandas as pd
8
- import gradio as gr
9
- import threading
10
- import time
11
- import zipfile
12
- import shutil
13
- import glob
14
- from pathlib import Path
15
- from torch.utils.data import DataLoader, Dataset, random_split
16
- import torchvision.transforms.v2 as transforms
17
- from PIL import Image, ImageDraw, ImageFont
18
- import imageio
19
- from tqdm import tqdm
20
- import tarfile
21
- import queue
22
- import hashlib
23
- import json
24
-
25
-
26
- # Set seeds for reproducibility
27
- torch.manual_seed(0)
28
- random.seed(0)
29
- np.random.seed(0)
30
-
31
- # Define constants
32
- IMG_SIZE = 512
33
- BATCH_SIZE = 32
34
- DEVICE = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
35
-
36
- # Add these global variables after the imports and before the CSS definition
37
- # Global variables for LoRA training
38
- lora_status = "Ready"
39
- lora_is_processing = False
40
-
41
- # Global variables for generation control
42
- generation_should_stop = False
43
- classifier_should_stop = False # New flag for classifier training
44
- embedding_should_stop = False # New flag for embedding encoding
45
- lora_should_stop = False # New flag for LoRA training
46
- generation_queue = queue.Queue()
47
- is_processing = False # Add this to prevent multiple simultaneous processes
48
-
49
- # Create temporary directories for uploads
50
- temp_dir = Path("./temp_uploads")
51
- temp_dir.mkdir(exist_ok=True, parents=True)
52
-
53
- lora_temp_dir = Path("./temp_lora_uploads")
54
- lora_temp_dir.mkdir(exist_ok=True, parents=True)
55
-
56
- # Create a global queue for real-time updates
57
- result_queue = queue.Queue()
58
- displayed_results = [] # Keep track of all displayed results
59
-
60
- # Add these global variables at the top of your file
61
- total_images_to_process = 0
62
- images_processed = 0
63
-
64
- # Add these global variables after the existing ones
65
- displayed_results_class0_to_class1 = [] # Results for class 0 to class 1
66
- displayed_results_class1_to_class0 = [] # Results for class 1 to class 0
67
-
68
- # Add global variables for caching
69
- CACHE_DIR = Path("./cached_results")
70
- CACHE_DIR.mkdir(exist_ok=True, parents=True)
71
-
72
- # CSS for styling the interface
73
- css = """
74
- @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
75
-
76
- body, * {
77
- font-family: 'Inter', sans-serif !important;
78
- letter-spacing: -0.01em;
79
- }
80
-
81
- .container {
82
- max-width: 1360px;
83
- margin: auto;
84
- padding-top: 2.5rem;
85
- padding-bottom: 2.5rem;
86
- }
87
-
88
- .header {
89
- text-align: center;
90
- margin-bottom: 3rem;
91
- padding-bottom: 2rem;
92
- border-bottom: 1px solid #f0f0f0;
93
- }
94
-
95
- .header h1 {
96
- font-size: 3rem;
97
- font-weight: 700;
98
- color: #222;
99
- letter-spacing: -0.03em;
100
- margin-bottom: 1rem;
101
- background: linear-gradient(90deg, #B39CD0 0%, #9D8AC7 100%);
102
- -webkit-background-clip: text;
103
- -webkit-text-fill-color: transparent;
104
- display: inline-block;
105
- }
106
-
107
- .header p {
108
- font-size: 1.1rem;
109
- color: #333;
110
- max-width: 800px;
111
- margin: 0 auto;
112
- line-height: 1.6;
113
- }
114
-
115
- .subtitle {
116
- font-size: 0.95rem;
117
- color: #777;
118
- max-width: 800px;
119
- margin: 0.5rem auto 0;
120
- line-height: 1.5;
121
- }
122
-
123
- .contact-info {
124
- font-size: 0.8rem;
125
- color: #777;
126
- margin-top: 15px;
127
- padding-top: 10px;
128
- border-top: 1px dashed #e0e0e0;
129
- width: 80%;
130
- margin-left: auto;
131
- margin-right: auto;
132
- }
133
-
134
- .paper-info {
135
- background-color: #f8f9fa;
136
- border-radius: 12px;
137
- padding: 1.8rem;
138
- margin: 1.8rem 0;
139
- box-shadow: 0 6px 20px rgba(0,0,0,0.05);
140
- border-left: 4px solid #B39CD0;
141
- }
142
-
143
- .paper-info h3 {
144
- font-size: 1.5rem;
145
- font-weight: 600;
146
- color: #B39CD0;
147
- letter-spacing: -0.02em;
148
- margin-bottom: 1rem;
149
- }
150
-
151
- .paper-info p {
152
- font-size: 1.05em;
153
- line-height: 1.7;
154
- color: #333;
155
- }
156
-
157
- .section-header {
158
- font-size: 1.8rem;
159
- font-weight: 600;
160
- color: #B39CD0;
161
- margin: 2.5rem 0 1.5rem 0;
162
- padding-bottom: 0.8rem;
163
- border-bottom: 2px solid #ECF0F1;
164
- letter-spacing: -0.02em;
165
- }
166
-
167
- .footer {
168
- text-align: center;
169
- margin-top: 3rem;
170
- padding: 1.5rem;
171
- border-top: 1px solid #ECF0F1;
172
- color: #666;
173
- background-color: #f8f9fa;
174
- border-radius: 0 0 12px 12px;
175
- }
176
-
177
- .btn-primary {
178
- background-color: #B39CD0 !important;
179
- border-color: #B39CD0 !important;
180
- transition: all 0.3s ease;
181
- font-weight: 500 !important;
182
- letter-spacing: 0.02em !important;
183
- padding: 0.6rem 1.5rem !important;
184
- border-radius: 8px !important;
185
- }
186
-
187
- .btn-primary:hover {
188
- background-color: #9D8AC7 !important;
189
- border-color: #9D8AC7 !important;
190
- }
191
-
192
- /* Hide the output directory */
193
- .hidden-element {
194
- display: none !important;
195
- }
196
-
197
- /* Additional CSS for better alignment */
198
- .container {
199
- padding: 0 1.5rem;
200
- }
201
-
202
- .main-container {
203
- display: flex;
204
- flex-direction: column;
205
- gap: 1.5rem;
206
- }
207
-
208
- .results-container {
209
- margin-top: 0;
210
- padding-top: 0;
211
- }
212
-
213
- .full-width-header {
214
- margin-bottom: 2rem;
215
- padding-bottom: 1.5rem;
216
- border-bottom: 1px solid #f0f0f0;
217
- text-align: center;
218
- }
219
-
220
- .content-row {
221
- display: flex;
222
- gap: 2rem;
223
- }
224
-
225
- .sidebar {
226
- min-width: 250px;
227
- padding-right: 1.5rem;
228
- }
229
-
230
- .section-header {
231
- margin-top: 0;
232
- }
233
-
234
- .tabs-container {
235
- margin-top: 1rem;
236
- }
237
-
238
- .gallery-container {
239
- margin-top: 1rem;
240
- }
241
-
242
- /* Hide the output directory */
243
- .hidden-element {
244
- display: none !important;
245
- }
246
-
247
- .gallery-item img {
248
- object-fit: contain !important;
249
- height: 200px !important;
250
- width: auto !important;
251
- }
252
-
253
- /* Force GIFs to restart when tab is selected */
254
- .tabs-container .tabitem[style*="display: block"] .gallery-container img {
255
- animation: none;
256
- animation: reload-animation 0.1s;
257
- }
258
-
259
- @keyframes reload-animation {
260
- 0% { opacity: 0.99; }
261
- 100% { opacity: 1; }
262
- }
263
- """
264
-
265
- # Add to your global variables
266
- current_cache_key = None
267
- is_using_default_params = False
268
-
269
- # Update the EXAMPLE_DATASETS to include direct dataset paths, embeddings, and classifiers
270
- EXAMPLE_DATASETS = [
271
- {
272
- "name": "butterfly",
273
- "display_name": "Butterfly (Monarch vs Viceroy)",
274
- "description": "Dataset containing images of Monarch and Viceroy butterflies for counterfactual generation",
275
- "path": "/proj/vondrick/datasets/magnification/butterfly.tar.gz",
276
- "direct_dataset_path": "example_images/butterfly",
277
- "checkpoint_path": "/proj/vondrick2/mia/magnificationold/output/lora/butterfly/copper-forest-49/checkpoint-1800",
278
- "embeddings_path": "/proj/vondrick2/mia/diff-usion/results/clip_image_embeds/butterfly",
279
- "classifier_path": "/proj/vondrick2/mia/diff-usion/results/ensemble/butterfly",
280
- "class_names": ["class0", "class1"]
281
- },
282
- {
283
- "name": "afhq",
284
- "display_name": "Cats vs. Dogs (AFHQ)",
285
- "description": "Dataset containing images of table lamps and floor lamps",
286
- "direct_dataset_path": "example_images/afhq",
287
- "checkpoint_path": None,
288
- "embeddings_path": "/proj/vondrick2/mia/diff-usion/results/clip_image_embeds/afhq",
289
- "classifier_path": "/proj/vondrick2/mia/diff-usion/results/ensemble/afhq",
290
- "class_names": ["class0", "class1"]
291
- },
292
- {
293
- "name": "lamp",
294
- "display_name": "Lamps",
295
- "description": "Dataset containing images of table lamps and floor lamps",
296
- "path": "compressed_datasets/lampsfar.zip",
297
- "direct_dataset_path": "example_images/lamps",
298
- "checkpoint_path": "/proj/vondrick2/mia/diff-usion/lora_output_lampsfar/checkpoint-800",
299
- "embeddings_path": "/proj/vondrick2/mia/diff-usion/results/clip_image_embeds/lampsfar",
300
- "classifier_path": "/proj/vondrick2/mia/diff-usion/results/ensemble/lampsfar",
301
- "class_names": ["class0", "class1"]
302
- },
303
- {
304
- "name": "couches",
305
- "display_name": "Couches",
306
- "description": "Dataset containing images of chairs and floor",
307
- "path": "compressed_datasets/couches.zip",
308
- "direct_dataset_path": "example_images/couches",
309
- "embeddings_path": "/proj/vondrick2/mia/diff-usion/results/clip_image_embeds/couches",
310
- "checkpoint_path": "/proj/vondrick2/mia/diff-usion/lora_output/couches/checkpoint-1000",
311
- "class_names": ["class0", "class1"]
312
- }
313
- ]
314
-
315
- # Function to get available example datasets
316
def get_example_datasets():
    """Return the internal names of all bundled example datasets."""
    return [entry["name"] for entry in EXAMPLE_DATASETS]
319
-
320
- # Function to get example dataset info
321
def get_example_dataset_info(name):
    """Look up the config dict of an example dataset by internal name.

    Returns None when no dataset with that name exists.
    """
    return next((entry for entry in EXAMPLE_DATASETS if entry["name"] == name), None)
327
-
328
- #Function to check if we're using default parameters
329
def is_using_default_params(dataset_name, custom_tskip, num_images_per_class):
    """Return True when the given parameters match the dataset's defaults.

    The t-skip value may arrive as int or str depending on the UI widget,
    so both forms are accepted.

    NOTE(review): this `def` shadows the module-level boolean flag of the
    same name defined earlier in the file — confirm the flag is unused.
    """
    if dataset_name is None:
        return False
    lowered = dataset_name.lower()
    if "butterfly" in lowered:
        return custom_tskip in (70, "70") and num_images_per_class == 10
    if "lamp" in lowered or "couch" in lowered:
        return custom_tskip in (85, "85") and num_images_per_class == 10
    return False
340
-
341
- # # Function to get the output directory - either cache or regular output
342
- # def get_output_directory(dataset_name, is_default_params, cache_key):
343
- # """Get the appropriate output directory based on parameters"""
344
- # if is_default_params:
345
- # # Use cache directory
346
- # cache_path = CACHE_DIR / cache_key
347
- # cache_path.mkdir(exist_ok=True, parents=True)
348
-
349
- # # Create dataset-specific directory
350
- # dataset_dir = cache_path / dataset_name.replace(" ", "_").lower()
351
- # dataset_dir.mkdir(exist_ok=True, parents=True)
352
-
353
- # # Create class-specific directories
354
- # class0_to_class1_dir = dataset_dir / "class0_to_class1"
355
- # class1_to_class0_dir = dataset_dir / "class1_to_class0"
356
- # class0_to_class1_dir.mkdir(exist_ok=True, parents=True)
357
- # class1_to_class0_dir.mkdir(exist_ok=True, parents=True)
358
-
359
- # # Create context directory
360
- # context_dir = dataset_dir / "context"
361
- # context_dir.mkdir(exist_ok=True, parents=True)
362
-
363
- # return dataset_dir, class0_to_class1_dir, class1_to_class0_dir, context_dir
364
- # else:
365
- # # Use regular output directory
366
- # output_dir = Path(f"./results/{dataset_name.replace(' ', '_').lower()}")
367
- # output_dir.mkdir(exist_ok=True, parents=True)
368
-
369
- # # Create gifs directory with class-specific subdirectories
370
- # gifs_dir = output_dir / "gifs"
371
- # gifs_dir.mkdir(exist_ok=True, parents=True)
372
-
373
- # class0_to_class1_dir = gifs_dir / "class0_to_class1"
374
- # class1_to_class0_dir = gifs_dir / "class1_to_class0"
375
- # class0_to_class1_dir.mkdir(exist_ok=True, parents=True)
376
- # class1_to_class0_dir.mkdir(exist_ok=True, parents=True)
377
-
378
- # # Create context directory
379
- # context_dir = output_dir / "context"
380
- # context_dir.mkdir(exist_ok=True, parents=True)
381
-
382
- # return output_dir, class0_to_class1_dir, class1_to_class0_dir, context_dir
383
-
384
-
385
def has_prediction_flipped(orig_preds, new_preds):
    """Return True if any sample's predicted class (threshold 0.5) changed."""
    original_labels = orig_preds.preds > 0.5
    updated_labels = new_preds.preds > 0.5
    return (original_labels != updated_labels).any().item()
388
-
389
- # Function to extract uploaded zip or tar.gz file
390
def extract_archive(archive_file, extract_dir):
    """Extract a user-uploaded .zip or .tar.gz archive into `extract_dir`.

    The archive must contain `class0` and `class1` directories, either at the
    top level or inside a single subdirectory (in which case they are moved up).

    Returns the extraction directory as a string.

    Raises:
        ValueError: for unsupported archive formats, unsafe member paths, or
            when the class0/class1 layout cannot be found.
    """
    dest = Path(extract_dir)
    dest.mkdir(parents=True, exist_ok=True)

    file_path = Path(archive_file)

    if file_path.suffix.lower() == '.zip':
        with zipfile.ZipFile(archive_file, 'r') as zip_ref:
            zip_ref.extractall(dest)
    elif file_path.name.endswith(('.tar.gz', '.tgz')):
        with tarfile.open(archive_file, 'r:gz') as tar_ref:
            # FIX: uploads are untrusted — reject members that would escape
            # the destination directory (tar path-traversal / "tar slip").
            base = dest.resolve()
            for member in tar_ref.getmembers():
                member_dest = (dest / member.name).resolve()
                if member_dest != base and base not in member_dest.parents:
                    raise ValueError(f"Unsafe path in archive: {member.name}")
            tar_ref.extractall(dest)
    else:
        raise ValueError(f"Unsupported archive format: {file_path.suffix}. Please use .zip or .tar.gz")

    # The class folders may sit one level down (e.g. archive wrapped in a
    # single root directory) — if so, hoist them to the top.
    class0_dir = dest / "class0"
    class1_dir = dest / "class1"

    if not (class0_dir.exists() and class1_dir.exists()):
        for subdir in dest.iterdir():
            if subdir.is_dir():
                if (subdir / "class0").exists() and (subdir / "class1").exists():
                    shutil.move(str(subdir / "class0"), str(class0_dir))
                    shutil.move(str(subdir / "class1"), str(class1_dir))
                    break

    if not (class0_dir.exists() and class1_dir.exists()):
        raise ValueError("The uploaded archive must contain 'class0' and 'class1' directories or a subdirectory containing them")

    return str(dest)
430
-
431
- # Function to handle cached results (placeholder implementation)
432
def get_cached_result_info(name):
    """Placeholder cache lookup — always reports a miss (None)."""
    return None
436
-
437
- # Modify the TwoClassDataset class to accept num_samples_per_class as a parameter
438
class TwoClassDataset(Dataset):
    """Binary-classification image dataset rooted at a directory.

    The two class subdirectories are chosen by substring of the root path
    (the kermany / kiki_bouba / afhq datasets ship with their own folder
    names); any other root is expected to contain "class0" and "class1".

    Each item is a tuple (image, label, path_string) with label 0 or 1.
    """

    def __init__(self, root_dir, transform=None, num_samples_per_class=None):
        self.root_dir = Path(root_dir)
        self.transform = transform

        root_str = str(self.root_dir)
        if 'kermany' in root_str:
            subdirs = ("NORMAL", "DRUSEN")
        elif 'kiki_bouba' in root_str:
            subdirs = ("kiki", "bouba")
        elif 'afhq' in root_str:
            subdirs = ("dog", "cat")
        else:
            subdirs = ("class0", "class1")
        self.class0_dir = self.root_dir / subdirs[0]
        self.class1_dir = self.root_dir / subdirs[1]

        # Any file with an extension counts as an image candidate.
        self.class0_images = list(self.class0_dir.glob("*.*"))
        self.class1_images = list(self.class1_dir.glob("*.*"))

        # Optionally cap each class at the first N files found.
        if num_samples_per_class is not None:
            self.class0_images = self.class0_images[:num_samples_per_class]
            self.class1_images = self.class1_images[:num_samples_per_class]

        self.images = self.class0_images + self.class1_images
        self.labels = [0] * len(self.class0_images) + [1] * len(self.class1_images)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        path = self.images[idx]
        sample = Image.open(path).convert("RGB")
        if self.transform:
            sample = self.transform(sample)
        return sample, self.labels[idx], str(path)
483
-
484
def compute_lpips_similarity(images1, images2, reduction=None):
    """Placeholder LPIPS similarity: random per-sample scores, not a real metric.

    Returns a tensor of shape (batch,), or its scalar mean when
    reduction == "mean".
    """
    # NOTE: a real implementation would run an LPIPS network; this stub
    # only matches the output shape/device contract for demo purposes.
    batch = images1.shape[0]
    scores = torch.rand(batch, device=images1.device)
    return scores.mean() if reduction == "mean" else scores
494
-
495
def get_direction_sign(idx: int):
    """Map a manipulation direction index to its sign: 0 -> -1, 1 -> +1."""
    signs = {0: -1, 1: 1}
    if idx not in signs:
        raise ValueError("Currently two direction are supported in this script")
    return signs[idx]
503
-
504
def add_text_to_image(image, text):
    """Overlay `text` at the top of a PIL image above a dark gradient strip.

    Mutates `image` in place and returns it.
    """
    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.truetype("arial.ttf", 24)
    except OSError:
        # FIX: was a bare `except:`; ImageFont.truetype raises OSError when
        # the font file cannot be found/read.
        font = ImageFont.load_default()

    # FIX: removed a dead `draw.textsize(...)` call whose result was unused
    # (textsize is also gone in Pillow >= 10).

    # Semi-transparent backdrop: 40 one-pixel rows fading from alpha 180 down.
    for i in range(40):
        alpha = max(0, 180 - i * 4)
        draw.rectangle([(0, i), (image.width, i)], fill=(0, 0, 0, alpha))

    draw.text((15, 10), text, fill="white", font=font)
    return image
526
-
527
def create_gif(img1, img2, output_path):
    """Create a GIF alternating between an original and a generated image.

    Each frame is labeled ("Original" / "Generated") with a 1px dark outline
    for legibility; frames show for 1 second each and the GIF loops forever.
    Returns `output_path`.
    """
    padding = 15

    def _labeled_copy(img, label):
        # Copy so the caller's image is never mutated, then draw the label
        # with a four-direction shadow/outline followed by the white text.
        frame = img.copy()
        draw = ImageDraw.Draw(frame)
        try:
            # Large font for visibility in the gallery thumbnails.
            font = ImageFont.truetype("arial.ttf", 36)
        except OSError:
            # FIX: was a bare `except:`; truetype raises OSError when the
            # font file is unavailable.
            font = ImageFont.load_default()
        for dx, dy in [(1, 1), (-1, 1), (1, -1), (-1, -1)]:
            draw.text(
                (padding + dx, padding + dy),
                label,
                fill=(0, 0, 0, 180),
                font=font
            )
        draw.text(
            (padding, padding),
            label,
            fill=(255, 255, 255, 230),
            font=font
        )
        return frame

    frames = [_labeled_copy(img1, "Original"), _labeled_copy(img2, "Generated")]
    # 1000 ms per frame, loop=0 means loop forever.
    imageio.mimsave(output_path, frames, duration=1000, loop=0)
    return output_path
585
-
586
- # Modify the update_progress_status function to be more informative
587
def update_progress_status():
    """Return a human-readable progress line for counterfactual generation.

    Reads the module-level counters; no `global` needed since nothing is
    assigned here.
    """
    done, total = images_processed, total_images_to_process

    if not is_processing:
        # Finished (or idle): report the final count if any work was done.
        if done > 0:
            return f"Processing complete. Generated {done} counterfactual images."
        return "Ready to process images."

    if total == 0:
        return "Preparing to process images..."

    return f"Progress: {done}/{total} images processed ({done / total * 100:.1f}%)"
601
-
602
- # Add function to cancel generation
603
def cancel_generation():
    """Request that every long-running worker stop at its next checkpoint."""
    global generation_should_stop, classifier_should_stop, embedding_should_stop, lora_should_stop

    # Flip every cooperative stop flag; the workers poll these between steps.
    generation_should_stop = True
    classifier_should_stop = True
    embedding_should_stop = True
    lora_should_stop = True

    return "All processes have been requested to stop. This may take a moment to complete."
614
-
615
def save_results_to_cache(output_dir, cache_key):
    """Copy a run's "gifs" and "context" folders into CACHE_DIR/<cache_key>.

    Existing cached copies of either folder are replaced wholesale so the
    cache never mixes results from two different runs.
    """
    destination = CACHE_DIR / cache_key
    destination.mkdir(exist_ok=True, parents=True)

    for subfolder in ("gifs", "context"):
        src = Path(output_dir) / subfolder
        dst = destination / subfolder
        if src.exists():
            if dst.exists():
                shutil.rmtree(dst)
            # copytree preserves the class-specific subdirectory structure.
            shutil.copytree(src, dst)
640
-
641
- # Update the process_with_selected_dataset function to handle the new directory structure
642
def process_with_selected_dataset(zip_file, output_dir, dataset_display_name, checkpoint_path=None, train_clf=True,
                                  is_direct_path=False, direct_path=None, embeddings_path=None,
                                  classifier_path=None, use_classifier_stopping=True, custom_tskip=85,
                                  manip_val=2):
    """Serve cached counterfactual results for the selected example dataset.

    Looks up the dataset by display name, derives a cache key from the
    generation parameters, and — when cached GIFs exist — populates the
    global gallery lists and returns them. Returns a 7-tuple:
    (message, all_gifs, class0->1 gifs, class1->0 gifs, status,
     class0 context image path, class1 context image path).
    """
    # FIX: declared global so the assignment below updates the module-level
    # cache key instead of creating a dead local variable.
    global current_cache_key

    print(f"\nProcessing with dataset: {dataset_display_name}")

    # Find the selected dataset by its display name.
    selected_dataset = None
    for dataset in EXAMPLE_DATASETS:
        if dataset["display_name"] == dataset_display_name:
            selected_dataset = dataset
            break

    if not selected_dataset:
        print("Error: No dataset selected")
        return "No dataset selected", [], [], [], "Error: No dataset selected", None, None

    # Cache key encodes every parameter that affects the generated output.
    cache_key = get_cache_key(
        selected_dataset["name"], checkpoint_path, False, embeddings_path,
        classifier_path, use_classifier_stopping, custom_tskip,
        manip_val,
    )

    print(f"Generated cache key: {cache_key}")

    cache_path = CACHE_DIR / cache_key
    dataset_dir = cache_path / "gifs"
    print(f"Looking for cache in: {cache_path}")
    print(f"Looking for gifs in: {dataset_dir}")
    print(f"Cache exists: {cache_path.exists()}")
    print(f"Gifs dir exists: {dataset_dir.exists()}")

    if cache_path.exists() and dataset_dir.exists():
        current_cache_key = cache_key
        print(f"Found cached results for key: {cache_key}")

        class0_to_class1_dir = dataset_dir / "class0_to_class1"
        class1_to_class0_dir = dataset_dir / "class1_to_class0"
        context_dir = cache_path / "context"

        class0_to_class1_gifs = list(class0_to_class1_dir.glob("*.gif")) if class0_to_class1_dir.exists() else []
        class1_to_class0_gifs = list(class1_to_class0_dir.glob("*.gif")) if class1_to_class0_dir.exists() else []

        # Sort by filename so the gallery order is deterministic.
        class0_to_class1_gifs.sort(key=lambda p: p.name)
        class1_to_class0_gifs.sort(key=lambda p: p.name)

        # Context images are optional.
        class0_context = context_dir / "class0_sample.jpg" if (context_dir / "class0_sample.jpg").exists() else None
        class1_context = context_dir / "class1_sample.jpg" if (context_dir / "class1_sample.jpg").exists() else None

        class0_to_class1_paths = [str(p) for p in class0_to_class1_gifs]
        class1_to_class0_paths = [str(p) for p in class1_to_class0_gifs]
        all_gifs = class0_to_class1_paths + class1_to_class0_paths

        # Update the module-level gallery state read by the UI callbacks.
        global displayed_results, displayed_results_class0_to_class1, displayed_results_class1_to_class0
        displayed_results = all_gifs
        displayed_results_class0_to_class1 = class0_to_class1_paths
        displayed_results_class1_to_class0 = class1_to_class0_paths

        status_message = f"Using cached results with t-skip={custom_tskip}, manip_scale={manip_val}"

        return (
            "Using cached results for default parameters.",
            displayed_results,
            displayed_results_class0_to_class1,
            displayed_results_class1_to_class0,
            status_message,
            str(class0_context) if class0_context else None,
            str(class1_context) if class1_context else None
        )

    # FIX: removed an unreachable bare `return` that followed this branch.
    print("No cached results found, processing dataset...")
    return "No cached results found, processing dataset...", [], [], [], "No cached results found, processing dataset...", None, None
727
-
728
-
729
-
730
- # def process_and_clear(example_datasets_dropdown, checkpoint_path_state,
731
- # is_direct_path_state, direct_path_state, embeddings_path_state,
732
- # classifier_path_state, use_classifier_stopping, custom_tskip,
733
- # manip_val):
734
- # """Clear folders first, then process the dataset"""
735
- # # Clear folders first
736
- # clear_output_folders()
737
-
738
- # # Then process the dataset
739
- # return process_with_selected_dataset(
740
- # None, # input_zip (always None)
741
- # "./output", # output_dir (hardcoded)
742
- # example_datasets_dropdown,
743
- # checkpoint_path_state,
744
- # False, # train_clf (always False)
745
- # is_direct_path_state,
746
- # direct_path_state,
747
- # embeddings_path_state,
748
- # classifier_path_state,
749
- # use_classifier_stopping,
750
- # custom_tskip,
751
- # manip_val
752
- # )
753
-
754
def process_and_clear(example_datasets_dropdown, checkpoint_path_state,
                      is_direct_path_state, direct_path_state, embeddings_path_state,
                      classifier_path_state, use_classifier_stopping, custom_tskip,
                      manip_val):
    """Clear the output folders, run generation, and return the gallery outputs.

    Returns (all_gifs, class0->1 gifs, class1->0 gifs, progress status) —
    the example context images are intentionally left untouched.
    """
    clear_output_folders()

    outputs = process_with_selected_dataset(
        None,          # no uploaded archive in this flow
        "./output",    # fixed output directory
        example_datasets_dropdown,
        checkpoint_path_state,
        False,         # never retrain the classifier here
        is_direct_path_state,
        direct_path_state,
        embeddings_path_state,
        classifier_path_state,
        use_classifier_stopping,
        custom_tskip,
        manip_val
    )

    # Drop element 0 (message) and 5/6 (context images); the UI only wants
    # the three galleries plus the progress status.
    return outputs[1], outputs[2], outputs[3], outputs[4]
785
-
786
def update_example_images(dataset_display_name):
    """Return (class0_image_path, class1_image_path) samples for a dataset.

    Returns (None, None) when the dataset is unknown or no images are found.
    """
    print(f"\nUpdating example images for {dataset_display_name}")

    # Find the dataset info by display name.
    selected_dataset = None
    for dataset in EXAMPLE_DATASETS:
        print(f"Checking dataset: {dataset['display_name']}", dataset_display_name)
        if dataset["display_name"] == dataset_display_name:
            selected_dataset = dataset
            print(f"Selected dataset: {selected_dataset}")
            break

    if selected_dataset:
        # FIX: class_names was read from selected_dataset *before* the None
        # check, raising AttributeError for unknown display names. A missing
        # "class_names" entry now falls back to the standard folder names.
        class_names = selected_dataset.get("class_names") or ["class0", "class1"]
        dataset_dir = selected_dataset.get("direct_dataset_path")
        print(f"Dataset directory: {dataset_dir}")

        if dataset_dir:
            # Debug: list everything under the dataset directory.
            print("Contents of directory:")
            for path in Path(dataset_dir).rglob("*"):
                print(f"  {path}")

            class0_path = Path(dataset_dir) / class_names[0]
            class1_path = Path(dataset_dir) / class_names[1]
            print(f"Looking in class0: {class0_path}")
            print(f"Looking in class1: {class1_path}")

            # First file of any extension in each class folder.
            class0_img = next((str(p) for p in Path(dataset_dir).glob(f"{class_names[0]}/*.*")), None)
            class1_img = next((str(p) for p in Path(dataset_dir).glob(f"{class_names[1]}/*.*")), None)

            print(f"Found images:\nclass0={class0_img}\nclass1={class1_img}")
            return class0_img, class1_img

    print("No images found")
    return None, None
826
- # Add a state variable to store the direct dataset path
827
- direct_path_state = gr.State(None)
828
- # Map display names back to internal names (add this back)
829
def get_name_from_display(display_name):
    """Map a dataset display name back to its internal name (None if unknown)."""
    return next(
        (entry["name"] for entry in EXAMPLE_DATASETS if entry["display_name"] == display_name),
        None,
    )
834
-
835
- # Modify the use_selected_dataset function
836
def use_selected_dataset(display_name):
    """Resolve a dataset display name to its on-disk resources.

    Returns a 6-tuple:
    (archive_path, checkpoint_path, is_direct_path, direct_dataset_path,
     embeddings_path, classifier_path) — with Nones/False when unresolved.
    """
    name = get_name_from_display(display_name)
    if not name:
        print("No dataset name found")
        return None, None, False, None, None, None

    dataset_info = get_example_dataset_info(name)

    # Prefer an already-extracted dataset directory when one exists.
    if dataset_info and "direct_dataset_path" in dataset_info and os.path.exists(dataset_info["direct_dataset_path"]):
        print(f"Using direct dataset path: {dataset_info['direct_dataset_path']}")
        return None, dataset_info["checkpoint_path"], True, dataset_info["direct_dataset_path"], \
               dataset_info.get("embeddings_path"), dataset_info.get("classifier_path")

    # FIX: some datasets (e.g. afhq) define no "path" key, so
    # dataset_info["path"] raised KeyError here; use .get() instead.
    archive_path = dataset_info.get("path") if dataset_info else None
    if archive_path and os.path.exists(archive_path):
        return archive_path, dataset_info["checkpoint_path"], False, None, \
               dataset_info.get("embeddings_path"), dataset_info.get("classifier_path")

    return None, None, False, None, None, None
855
def reset_galleries():
    """Clear all gallery state: global result lists, cache key, pending queue.

    Returns empty lists for the three galleries plus a status string.
    """
    global displayed_results, displayed_results_class0_to_class1, displayed_results_class1_to_class0
    global current_cache_key

    displayed_results = []
    displayed_results_class0_to_class1 = []
    displayed_results_class1_to_class0 = []
    # Forget which cache the galleries were showing.
    current_cache_key = None

    # Drain any results still queued for display.
    while not result_queue.empty():
        result_queue.get()

    return [], [], [], "Galleries reset"
870
def clear_output_folders():
    """Delete ./output/gifs and ./output/context (and their contents) if present."""
    # FIX: removed redundant function-local `import shutil` / `from pathlib
    # import Path` — both are already imported at module level.
    for folder in ("gifs", "context"):
        folder_path = Path("./output") / folder
        if folder_path.exists():
            shutil.rmtree(folder_path)
            print(f"Deleted {folder_path}")
883
-
884
def create_gradio_interface():
    """Assemble the full Gradio demo UI and wire up all event handlers.

    Returns:
        gr.Blocks: the assembled (but not yet launched) Gradio app.
    """
    # Create temporary directories for uploads
    temp_dir = Path("./temp_uploads")
    temp_dir.mkdir(exist_ok=True, parents=True)

    # Start every session from a clean slate: remove GIF/context output of earlier runs
    clear_output_folders()

    lora_temp_dir = Path("./temp_lora_uploads")
    lora_temp_dir.mkdir(exist_ok=True, parents=True)

    # Get initial list of example datasets
    example_datasets = get_example_datasets()

    with gr.Blocks(css=css) as demo:
        # Header row spanning all columns
        with gr.Row(elem_classes="full-width-header"):
            with gr.Column():
                gr.HTML("""
                <div class="header">
                    <h1>DIFFusion Demo</h1>
                    <p class="subtitle">Generate fine-grained edits to images using another class of images as guidance.</p>
                    <p class="contact-info">For any questions/comments/issues with this demo, please email mia.chiquier@cs.columbia.edu.🤖</p>
                </div>
                """)

        # Main content row with sidebar, config column and results column
        with gr.Row(elem_classes="content-row"):
            # Sidebar for example datasets
            with gr.Column(scale=1, elem_classes="sidebar"):
                gr.HTML('<div class="section-header">Example Datasets</div>')

                # Dropdown of pre-loaded datasets; the lamp dataset is the default selection
                example_datasets_dropdown = gr.Dropdown(
                    choices=[dataset["display_name"] for dataset in EXAMPLE_DATASETS],
                    value=next((dataset["display_name"] for dataset in EXAMPLE_DATASETS if "lamp" in dataset["display_name"].lower()), None),
                    label="Example Datasets",
                    info="Select a pre-loaded dataset to use"
                )

                # Display-name -> description lookup (built but not currently rendered)
                dataset_descriptions = {dataset["display_name"]: dataset.get("description", "") for dataset in EXAMPLE_DATASETS}

                # Visual spacing only
                gr.HTML("<div style='height: 20px;'></div>")

                # Hidden component: updated by update_dataset_info but never shown
                dataset_description = gr.Textbox(visible=False)

            # Main content area: paper blurb plus generation controls
            with gr.Column(scale=2, elem_classes="main-container"):
                with gr.Column():
                    with gr.Column(elem_classes="paper-info"):
                        gr.HTML("""
                        <h3>DIFFusion Demo</h3>
                        <p>Articulating specific visual transformations for AI image editing can be challenging. Our image-guided editing method addresses this by learning transformations directly from differences between two image groups, eliminating the need for predefined verbal descriptions. Tailored for scientific applications, it effectively reveals subtle differences in visually similar image categories. The approach also extends to marketing, where it automatically adapts new products into scenes, handling small interior design decisions to showcase items seamlessly. A Gradio demo, included in our GitHub code release, allows users to upload datasets and apply the method (note: GPU required). Explore the demo to test our method on your own images.</p>
                        """)

                    # Counterfactual Generation Section
                    gr.HTML('<div class="section-header">Counterfactual Generation</div>')

                    with gr.Row():
                        # Classifier-based early stopping is currently disabled (hidden state)
                        use_classifier_stopping = gr.State(False)

                        # t-skip controls how much of the diffusion trajectory is skipped
                        custom_tskip = gr.Dropdown(
                            choices=[55, 60, 65, 70, 75, 80, 85, 90, 95],
                            value=85,  # default value; overridden per-dataset on selection
                            label="Custom T-Skip Value",
                            info="Select a t-skip value",
                            visible=True
                        )

                    with gr.Row():
                        # Strength of the learned manipulation direction
                        manip_val = gr.Dropdown(
                            choices=[1.0, 1.5, 2.0],
                            value=2.0,  # default value
                            label="Manip scale",
                            info="Select a manip scale",
                            visible=True
                        )

                    with gr.Row():
                        process_btn = gr.Button("Generate Counterfactuals", elem_classes="btn-primary")
                        cancel_btn = gr.Button("Cancel Generation", elem_classes="btn-primary")

            # Results column
            with gr.Column(scale=2, elem_classes="results-container"):
                # One example image per class so the user can see what each class looks like
                gr.HTML('<div class="section-header">Class Examples</div>')

                with gr.Row():
                    class0_context_image = gr.Image(label="Class 0 Example", type="filepath", height=256)
                    class1_context_image = gr.Image(label="Class 1 Example", type="filepath", height=256)

                # Results section header
                gr.HTML('<div class="section-header">Results</div>')

                # Pre-populate the example images for the default ("lamps") dataset
                default_dataset = next((dataset["display_name"] for dataset in EXAMPLE_DATASETS if "lamps" in dataset["display_name"].lower()), None)
                if default_dataset:
                    class0_img, class1_img = update_example_images(default_dataset)
                    if class0_img and class1_img:
                        class0_context_image.value = class0_img  # Directly set the value
                        class1_context_image.value = class1_img

                print(f"Class 0 image: {class0_context_image.value}")
                print(f"Class 1 image: {class1_context_image.value}")

                # Tabs for the generated results; "All Results" is the default tab
                with gr.Tabs(elem_classes="tabs-container") as result_tabs:
                    with gr.TabItem("All Results"):
                        gallery = gr.Gallery(
                            label="Generated Images",
                            show_label=False,
                            elem_id="gallery_all",
                            columns=4,  # Show 4 images per row
                            rows=None,  # Let it adjust rows automatically
                            height="auto",
                            allow_preview=True,
                            preview=False,
                            object_fit="contain"
                        )

                    with gr.TabItem("Class 0 → Class 1"):
                        gallery_class0_to_class1 = gr.Gallery(
                            label="Class 0 to Class 1",
                            show_label=False,
                            elem_id="gallery_0to1",
                            columns=4,  # Show 4 images per row
                            rows=None,  # Let it adjust rows automatically
                            height="auto",
                            allow_preview=True,
                            preview=True,
                            object_fit="contain"
                        )

                    with gr.TabItem("Class 1 → Class 0"):
                        gallery_class1_to_class0 = gr.Gallery(
                            label="Class 1 to Class 0",
                            show_label=False,
                            elem_id="gallery_1to0",
                            columns=4,  # Show 4 images per row
                            rows=None,  # Let it adjust rows automatically
                            height="auto",
                            allow_preview=True,
                            preview=True,
                            object_fit="contain"
                        )

                # Live progress readout for the generation pipeline
                progress_status = gr.Textbox(
                    label="Progress",
                    value="Ready to process",
                    interactive=False
                )

        # State variables seeded from the lamp dataset's paths.
        # NOTE(review): these States are only created when a "lamp" dataset exists in
        # EXAMPLE_DATASETS; the event wiring below references them unconditionally and
        # would raise NameError otherwise — confirm a lamp dataset is always present.
        default_dataset = next((dataset for dataset in EXAMPLE_DATASETS if "lamp" in dataset["display_name"].lower()), None)
        if default_dataset:
            checkpoint_path_state = gr.State(default_dataset["checkpoint_path"])
            is_direct_path_state = gr.State(False)
            direct_path_state = gr.State(None)
            embeddings_path_state = gr.State(default_dataset["embeddings_path"])
            classifier_path_state = gr.State(default_dataset["classifier_path"])

        # Manual generation trigger
        process_btn.click(
            fn=process_and_clear,
            inputs=[
                example_datasets_dropdown, checkpoint_path_state,
                is_direct_path_state, direct_path_state, embeddings_path_state,
                classifier_path_state, use_classifier_stopping, custom_tskip,
                manip_val
            ],
            outputs=[
                gallery,
                gallery_class0_to_class1,
                gallery_class1_to_class0,
                progress_status
            ]
        )

        # Cancel button: flips the global stop flag via cancel_generation
        cancel_btn.click(
            fn=cancel_generation,
            inputs=None,
            outputs=None
        )

        num_images_per_class = gr.State(10)

        # Dataset change pipeline: reset galleries -> refresh dataset info ->
        # set per-dataset t-skip -> recompute cache key -> refresh example images ->
        # regenerate counterfactuals.
        # NOTE(review): change_cache_key's second parameter is manip_val but
        # num_images_per_class is passed here — verify this is intentional.
        example_datasets_dropdown.change(
            fn=reset_galleries,  # Reset galleries but not example images
            inputs=None,
            outputs=[gallery, gallery_class0_to_class1, gallery_class1_to_class0, progress_status]
        ).then(  # Update dataset info
            fn=update_dataset_info,
            inputs=example_datasets_dropdown,
            outputs=[dataset_description, checkpoint_path_state, is_direct_path_state, direct_path_state,
                     embeddings_path_state, classifier_path_state, custom_tskip]
        ).then(  # Set custom t-skip
            fn=set_custom_tskip_for_dataset,
            inputs=example_datasets_dropdown,
            outputs=custom_tskip
        ).then(  # Change cache key
            fn=change_cache_key,
            inputs=[example_datasets_dropdown, num_images_per_class, use_classifier_stopping, custom_tskip],
            outputs=None
        ).then(  # Update example images
            fn=update_example_images,
            inputs=example_datasets_dropdown,
            outputs=[class0_context_image, class1_context_image]
        ).then(  # Automatically generate counterfactuals when dataset changes
            fn=process_and_clear,
            inputs=[
                example_datasets_dropdown, checkpoint_path_state,
                is_direct_path_state, direct_path_state, embeddings_path_state,
                classifier_path_state, use_classifier_stopping, custom_tskip,
                manip_val
            ],
            outputs=[gallery, gallery_class0_to_class1, gallery_class1_to_class0,
                     progress_status]
        )

        # On page load: show the default dataset's example images, then run an
        # initial counterfactual generation so the galleries are never empty.
        demo.load(
            fn=update_example_images,
            inputs=example_datasets_dropdown,
            outputs=[class0_context_image, class1_context_image]
        ).then(  # Initial counterfactual generation
            fn=process_and_clear,
            inputs=[
                example_datasets_dropdown, checkpoint_path_state,
                is_direct_path_state, direct_path_state, embeddings_path_state,
                classifier_path_state, use_classifier_stopping, custom_tskip,
                manip_val
            ],
            outputs=[gallery, gallery_class0_to_class1, gallery_class1_to_class0,
                     progress_status]
        )

    return demo
def update_dataset_info(dataset_display_name):
    """Update dataset description and paths when dropdown changes"""
    # Look up the dataset record whose display name matches the selection
    selected_dataset = next(
        (entry for entry in EXAMPLE_DATASETS if entry["display_name"] == dataset_display_name),
        None,
    )

    if selected_dataset is None:
        return "No dataset selected", None, False, None, None, None, None

    description = selected_dataset.get("description", "No description available")

    # Resolve the artifact paths stored on the dataset record
    checkpoint_path = selected_dataset.get("checkpoint_path", None)
    direct_path = selected_dataset.get("direct_dataset_path", None)
    is_direct_path = direct_path is not None
    embeddings_path = selected_dataset.get("embeddings_path", None)
    classifier_path = selected_dataset.get("classifier_path", None)

    # Per-dataset default t-skip (None for datasets without a preset)
    lowered_name = dataset_display_name.lower()
    if "butterfly" in lowered_name:
        custom_tskip = 70  # butterflies use a lower t-skip
    elif "lamp" in lowered_name:
        custom_tskip = 85
    else:
        custom_tskip = None

    print(f"Setting custom_tskip to {custom_tskip} for dataset {dataset_display_name}")

    return description, checkpoint_path, is_direct_path, direct_path, embeddings_path, classifier_path, custom_tskip
# Derive a deterministic cache key from the processing parameters
def get_cache_key(dataset_name, checkpoint_path, train_clf, embeddings_path,
                  classifier_path, use_classifier_stopping, custom_tskip, manip_val):
    """Generate a unique cache key based on the processing parameters"""
    # Normalize path-like values to strings and manip_val to float so that
    # equivalent parameter sets always hash identically.
    fingerprint = {
        "dataset_name": dataset_name,
        "checkpoint_path": str(checkpoint_path),
        "train_clf": train_clf,
        "embeddings_path": str(embeddings_path),
        "classifier_path": str(classifier_path),
        "use_classifier_stopping": use_classifier_stopping,
        "custom_tskip": custom_tskip,
        "manip_val": float(manip_val),
    }
    print(f"Params: {fingerprint}")

    # Sorted-key JSON gives a canonical serialization; md5 of it is the key
    canonical = json.dumps(fingerprint, sort_keys=True)
    return hashlib.md5(canonical.encode()).hexdigest()
def change_cache_key(dataset_name, manip_val, use_classifier_stopping, custom_tskip):
    """Change the cache key based on the selected dataset"""
    global current_cache_key

    # Locate the dataset record for this display name
    selected_dataset = next(
        (entry for entry in EXAMPLE_DATASETS if entry["display_name"] == dataset_name),
        None,
    )

    if selected_dataset is None:
        print(f"No dataset found for name: {dataset_name}")
        return

    # Recompute and store the module-level cache key from the dataset's artifacts
    current_cache_key = get_cache_key(
        selected_dataset["name"],  # Use internal name instead of display name
        selected_dataset.get("checkpoint_path", None),
        False,  # train_clf is always False
        selected_dataset.get("embeddings_path", None),
        selected_dataset.get("classifier_path", None),
        use_classifier_stopping,
        custom_tskip,
        manip_val,
    )
# Query whether a fully-populated cache entry exists for a key
def check_cache(cache_key):
    """Check if cached results exist for the given key"""
    cached = CACHE_DIR / cache_key
    if not cached.exists():
        return False
    # An entry only counts if its "gifs" subfolder was written
    return (cached / "gifs").exists()
# Build a single "context" image that previews one sample from a class
def create_context_image(image_paths, output_path, title, preferred_index=0):
    """Create a context image showing a sample from a class.

    Args:
        image_paths: List of paths to images in the class
        output_path: Where to save the context image
        title: Title drawn in a banner at the top of the image
        preferred_index: Index of the preferred image to use (default: 0)
    """
    if not image_paths:
        # No samples: emit a grey placeholder with a centered message
        img = Image.new('RGB', (512, 512), color=(240, 240, 240))
        draw = ImageDraw.Draw(img)
        try:
            font = ImageFont.truetype("arial.ttf", 32)
        except Exception:  # font file missing — fall back to PIL's builtin font
            font = ImageFont.load_default()
        draw.text((256, 256), "No samples available", fill=(80, 80, 80), font=font, anchor="mm")
        img.save(output_path)
        return

    # Use the preferred index if available, clamped to the last valid index
    img_index = min(preferred_index, len(image_paths) - 1)
    img = Image.open(image_paths[img_index]).convert("RGB")
    img = img.resize((512, 512), Image.LANCZOS)

    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype("arial.ttf", 32)
    except Exception:  # font file missing — fall back to PIL's builtin font
        font = ImageFont.load_default()

    # Dark banner behind the title (alpha component presumably ignored on an
    # RGB image — TODO confirm against the installed Pillow version)
    draw.rectangle([(0, 0), (img.width, 50)], fill=(0, 0, 0, 180))
    # FIX: the title parameter was previously accepted but never rendered;
    # draw it centered inside the banner.
    draw.text((img.width // 2, 25), title, fill=(255, 255, 255), font=font, anchor="mm")

    # Save the context image
    img.save(output_path)
# Mirror the t-skip state value into the text input component
def update_custom_tskip(tskip_value):
    """Update the custom_tskip input field with the value from the state"""
    print(f"Updating custom_tskip input with value: {tskip_value}")
    # The text input expects a string; None means "leave the field blank"
    return "" if tskip_value is None else str(tskip_value)
# Map a dataset display name to its default t-skip value
def set_custom_tskip_for_dataset(dataset_name):
    """Set the custom_tskip value based on the selected dataset"""
    # Butterflies use 70; everything else (including None/lamps) defaults to 85
    if dataset_name and "butterfly" in dataset_name.lower():
        return 70
    return 85
if __name__ == "__main__":
    # Entry point: build the Gradio UI and serve it.
    # Uncomment this line to save current results to cache
    #save_current_results_to_cache()

    demo = create_gradio_interface()
    demo.launch()
1443
-
1444
- # Add these functions at the top of the file, after the imports and global variables
1445
- # but before any other function definitions
1446
-
1447
- #