RishubhPar committed on
Commit 8b8b01e · verified · 1 Parent(s): 1b23c88

added the basic files

Files changed (2):
  1. app.py +667 -0
  2. requirements.txt +78 -0
app.py ADDED
@@ -0,0 +1,667 @@
+ import os
+ import gc
+ import json
+ from typing import List
+
+ import torch
+ import gradio as gr
+ from PIL import Image
+
+ from model.transformer_flux import FluxTransformer2DModelwithSliderConditioning
+ from model.sliders_model import SliderProjector, SliderProjector_wo_clip
+ from model.sliders_pipeline import FluxKontextSliderPipeline
+
+ from huggingface_hub import login
+
+ HF_TOKEN = os.getenv("HF_TOKEN")
+
+ if HF_TOKEN:
+     # Authenticate this process (does not print or persist the token in the logs)
+     login(token=HF_TOKEN)
+
+ # -----------------------------
+ # Environment & device
+ # -----------------------------
+ # Avoid meta-tensor init from environment leftovers
+ os.environ.pop("ACCELERATE_INIT_EMPTY_WEIGHTS", None)
+
+ DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
+ print("Using device:", DEVICE)
+
+ torch.backends.cudnn.benchmark = True
+
+ # -----------------------------
+ # Model / pipeline loading
+ # -----------------------------
+ def load_pipeline_single_gpu(device_str: str) -> FluxKontextSliderPipeline:
+     pretrained = "black-forest-labs/FLUX.1-Kontext-dev"
+
+     n_slider_layers = 4
+     slider_projector_out_dim = 6144
+     trained_models_path = "./model_weights/"
+     is_clip_input = True
+
+     # Load the transformer fully on CPU; avoid meta tensors
+     transformer = FluxTransformer2DModelwithSliderConditioning.from_pretrained(
+         pretrained,
+         subfolder="transformer",
+         device_map=None,
+         low_cpu_mem_usage=False,
+         token=HF_TOKEN,
+     )
+     weight_dtype = transformer.dtype  # keep the checkpoint dtype
+
+     # Slider projector
+     if is_clip_input:
+         slider_projector = SliderProjector(
+             out_dim=slider_projector_out_dim, pe_dim=2, n_layers=n_slider_layers, is_clip_input=True
+         )
+     else:
+         slider_projector = SliderProjector_wo_clip(
+             out_dim=slider_projector_out_dim, pe_dim=2, n_layers=n_slider_layers
+         )
+
+     # Put both models in inference (eval) mode
+     transformer.eval()
+     slider_projector.eval()
+
+     # Load the projector weights on CPU
+     slider_projector_path = os.path.join(trained_models_path, "slider_projector.pth")
+     state_dict = torch.load(slider_projector_path, map_location="cpu")
+     print(f"state_dict keys: {state_dict.keys()}")
+
+     slider_projector.load_state_dict(state_dict)
+     print(f"Loaded slider_projector from {slider_projector_path}")
+
+     # Build the full pipeline on CPU; no device_map sharding
+     pipeline = FluxKontextSliderPipeline.from_pretrained(
+         pretrained,
+         transformer=transformer,
+         slider_projector=slider_projector,
+         torch_dtype=weight_dtype,
+         device_map=None,
+         low_cpu_mem_usage=False,
+     )
+
+     print(f"Loading the pipeline LoRA weights from: {trained_models_path}")
+     pipeline.load_lora_weights(trained_models_path)
+     print(f"Loaded the pipeline LoRA weights from: {trained_models_path}")
+
+     # Move everything to the single device
+     pipeline.to(device_str)
+     return pipeline
+
+
+ PIPELINE = load_pipeline_single_gpu(DEVICE)
+ print(f"[init] Pipeline loaded on {DEVICE}")
+
+
+ # -----------------------------
+ # Sample Images & Precomputed Results
+ # -----------------------------
+
+ def create_sample_entry(name, image_filename, prompt, result_folder, num_results=5, result_pattern="image_{i}.png", precomputed_base="./sample_images/precomputed"):
+     """
+     Helper to create a sample entry with subfolder organization.
+
+     Args:
+         name: Display name in the dropdown
+         image_filename: Filename in ./sample_images/
+         prompt: Editing instruction
+         result_folder: Subfolder name in the precomputed directory
+         num_results: Number of precomputed results (default 5)
+         result_pattern: Filename pattern; {i} is replaced with 0, 1, 2, ... (default "image_{i}.png")
+         precomputed_base: Base path for precomputed results (default "./sample_images/precomputed")
+     """
+     return {
+         "name": name,
+         "image_path": f"./sample_images/{image_filename}",
+         "prompt": prompt,
+         "precomputed_results": [f"{precomputed_base}/{result_folder}/{result_pattern.format(i=i)}" for i in range(num_results)],
+     }
+
+ def load_samples_from_config(config_file="sample_config.json"):
+     """Load sample data from a JSON configuration file."""
+     if os.path.exists(config_file):
+         try:
+             with open(config_file, "r") as f:
+                 return json.load(f)
+         except Exception as e:
+             print(f"Error loading sample config: {e}")
+     return []
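+ # Example (illustrative): sample_config.json is expected to hold a JSON list of
+ # entries with the same keys that create_sample_entry produces, e.g.:
+ # [
+ #   {
+ #     "name": "My Edit",
+ #     "image_path": "./sample_images/my_image.png",
+ #     "prompt": "apply some effect",
+ #     "precomputed_results": ["./sample_images/precomputed/my_edit/image_0.png"]
+ #   }
+ # ]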
+
+ def discover_samples_automatically(sample_dir="./sample_images", precomputed_dir="./sample_images/precomputed"):
+     """Automatically discover samples from the subfolder structure of the precomputed directory."""
+     discovered_samples = []
+
+     if not os.path.exists(sample_dir) or not os.path.exists(precomputed_dir):
+         return discovered_samples
+
+     # Look for subfolders in the precomputed directory
+     for subfolder in os.listdir(precomputed_dir):
+         subfolder_path = os.path.join(precomputed_dir, subfolder)
+         if not os.path.isdir(subfolder_path):
+             continue
+
+         # Collect sequential result files (indices 0..14); stop at the first missing index
+         precomputed_files = []
+         for i in range(15):
+             # Try several common filename patterns for this index
+             for pattern in (f"image_{i}.png", f"image_{i}.jpg", f"{i}.jpg", f"{i}.png", f"result_{i}.jpg", f"output_{i}.png"):
+                 result_path = os.path.join(subfolder_path, pattern)
+                 if os.path.exists(result_path):
+                     precomputed_files.append(result_path)
+                     break
+             else:
+                 break  # no file for this index, so the sequence ends here
+
+         if precomputed_files:
+             # Try to find the corresponding source image
+             img_path = None
+             # Common naming pattern: "portrait" from a subfolder named "portrait_smile"
+             base_name = subfolder.split("_")[0]
+             for ext in (".jpg", ".jpeg", ".png"):
+                 candidate = os.path.join(sample_dir, f"{base_name}{ext}")
+                 if os.path.exists(candidate):
+                     img_path = candidate
+                     break
+
+             if img_path:
+                 discovered_samples.append({
+                     "name": f"{subfolder.replace('_', ' ').title()} - Auto-discovered",
+                     "image_path": img_path,
+                     "prompt": f"Edit: {subfolder.replace('_', ' ')}",  # default prompt
+                     "precomputed_results": precomputed_files,
+                 })
+
+     return discovered_samples
+
+ # Main sample data, matching this repository's folder structure
+ SAMPLE_DATA = [
+     create_sample_entry("Stylization", "aesthetic_model2_vangogh.png", "Transform the image into a Van Gogh Style painting", "aesthetic_model2_vangogh", 11),
+     create_sample_entry("Weather Change", "enfield3_winter_snow.png", "Transform the scene into winter season with heavy snowfall", "enfield3_winter_snow", 11),
+     create_sample_entry("Illumination Change", "light_lamp_blue_side.png", "Turn on the lamp with blue lighting", "light_lamp_blue_side", 11),
+     create_sample_entry("Appearance Change", "jackson_fluffy.png", "Transform his jacket into a blue fluffy fur jacket", "jackson_fluffy", 11),
+     create_sample_entry("Scene Edit", "venice1_grow_ivy.png", "Grow ivy on the walls of the buildings on the side", "venice1_grow_ivy", 11),
+ ]
+
+ # Additional samples can be registered with the same helper:
+ ADDITIONAL_SAMPLES = [
+     # For the structure ./sample_images/precomputed/<folder_name>/image_0.png, image_1.png, ...:
+     # create_sample_entry("Display Name", "your_image.png", "editing prompt", "folder_name", 12),
+     #
+     # Notes:
+     # - Source images go in ./sample_images/
+     # - Precomputed results go in ./sample_images/precomputed/<folder_name>/
+     # - The default filename pattern is image_0.png, image_1.png, ...
+     # - Set the count argument to match how many results exist
+ ]
+
+ # Extend the main sample data with the additional samples
+ SAMPLE_DATA.extend(ADDITIONAL_SAMPLES)
+
+ # Optional: auto-discover additional samples from the directories.
+ # Uncomment to pick up samples beyond the manual ones above:
+ # AUTO_DISCOVERED = discover_samples_automatically()
+ # if AUTO_DISCOVERED:
+ #     print(f"Auto-discovered {len(AUTO_DISCOVERED)} additional samples:")
+ #     for sample in AUTO_DISCOVERED:
+ #         print(f"  - {sample['name']}")
+ #     SAMPLE_DATA.extend(AUTO_DISCOVERED)
+
+ # Optional: load samples from an external JSON config
+ # CONFIG_SAMPLES = load_samples_from_config("sample_config.json")
+ # SAMPLE_DATA.extend(CONFIG_SAMPLES)
+
+ def load_sample_image(image_path: str) -> Image.Image:
+     """Load a sample image, falling back to a placeholder if the file doesn't exist."""
+     try:
+         if os.path.exists(image_path):
+             return Image.open(image_path)
+         # Placeholder image if the sample doesn't exist
+         return Image.new("RGB", (512, 512), color=(200, 200, 200))
+     except Exception as e:
+         print(f"Error loading sample image {image_path}: {e}")
+         return Image.new("RGB", (512, 512), color=(200, 200, 200))
+
+ def load_precomputed_results(result_paths: List[str]) -> List[Image.Image]:
+     """Load precomputed result images, with placeholders for missing files."""
+     results = []
+     for path in result_paths:
+         try:
+             if os.path.exists(path):
+                 results.append(Image.open(path))
+             else:
+                 results.append(Image.new("RGB", (512, 512), color=(150, 150, 150)))
+         except Exception as e:
+             print(f"Error loading precomputed result {path}: {e}")
+             results.append(Image.new("RGB", (512, 512), color=(150, 150, 150)))
+     return results
+
+
+ # -----------------------------
+ # Helpers
+ # -----------------------------
+ def resize_image(img: Image.Image, target: int = 512) -> Image.Image:
+     """Resize so the shortest side equals target, preserving aspect ratio (no crop)."""
+     w, h = img.size
+     try:
+         resample = Image.Resampling.BICUBIC  # Pillow >= 9.1
+     except AttributeError:
+         resample = Image.BICUBIC
+
+     if h > w:
+         new_w, new_h = target, int(target * h / w)
+     elif h < w:
+         new_w, new_h = int(target * w / h), target
+     else:
+         new_w, new_h = target, target
+
+     return img.resize((new_w, new_h), resample)
+
+
+ def _encode_prompt(prompt: str):
+     with torch.no_grad():
+         pe, ppe, _ = PIPELINE.encode_prompt(prompt, prompt_2=prompt)
+     return pe, ppe
+
+
+ # -----------------------------
+ # Inference functions
+ # -----------------------------
+ def generate_image_stack_edits(text_prompt, n_edits, input_image):
+     """
+     Compute n_edits images on a single GPU for slider values in (0, 1], and
+     return (list_of_images, first_image) so the UI updates immediately.
+     """
+     if input_image is None or not text_prompt or text_prompt.startswith("Please select"):
+         return [], None
+
+     n = int(n_edits) if n_edits is not None else 1
+     n = max(1, n)
+     slider_values = [(i + 1) / float(n) for i in range(n)]  # evenly spaced in (0, 1]
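+     # e.g. n = 5 -> slider_values = [0.2, 0.4, 0.6, 0.8, 1.0]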
+
+     img = resize_image(input_image, 512)
+     pe, ppe = _encode_prompt(text_prompt)
+
+     results: List[Image.Image] = []
+     gen_base = 64  # deterministic seed base
+
+     # No batching for now; a simple forward loop over the slider values
+     for i, sv in enumerate(slider_values):
+         gen = torch.Generator(device=DEVICE).manual_seed(gen_base + i)
+         with torch.no_grad():
+             out = PIPELINE(
+                 image=img,
+                 height=img.height,
+                 width=img.width,
+                 num_inference_steps=28,
+                 prompt_embeds=pe,
+                 pooled_prompt_embeds=ppe,
+                 generator=gen,
+                 text_condn=False,
+                 modulation_condn=True,
+                 slider_value=torch.tensor(sv, device=DEVICE).reshape(1, 1),
+                 is_clip_input=True,
+             )
+         results.append(out.images[0])
+
+         if DEVICE.startswith("cuda"):
+             torch.cuda.empty_cache()
+         gc.collect()
+
+     first = results[0] if results else None
+     return results, first
+
+
+ def generate_single_image(text_prompt, slider_value, input_image):
+     if input_image is None or not text_prompt or text_prompt.startswith("Please select"):
+         return None
+
+     img = resize_image(input_image, 512)
+     sv = float(slider_value)
+     pe, ppe = _encode_prompt(text_prompt)
+
+     gen = torch.Generator(device=DEVICE).manual_seed(64)
+     with torch.no_grad():
+         out = PIPELINE(
+             image=img,
+             height=img.height,
+             width=img.width,
+             num_inference_steps=28,
+             prompt_embeds=pe,
+             pooled_prompt_embeds=ppe,
+             generator=gen,
+             text_condn=False,
+             modulation_condn=True,
+             slider_value=torch.tensor(sv, device=DEVICE).reshape(1, 1),
+             is_clip_input=True,
+         )
+     result = out.images[0]
+
+     if DEVICE.startswith("cuda"):
+         torch.cuda.empty_cache()
+     gc.collect()
+     return result
+
+
+ # -----------------------------
+ # Sample Loading Functions
+ # -----------------------------
+ def get_sample_by_name(sample_name: str):
+     """Get sample data by name."""
+     for sample in SAMPLE_DATA:
+         if sample["name"] == sample_name:
+             return sample
+     return None
+
+ def load_sample_to_main_interface(sample_name: str):
+     """Load the selected sample into the main interface with its precomputed results."""
+     if not sample_name:
+         return (
+             None,
+             "Please select a sample above to see the editing instruction",
+             [],
+             None,
+             gr.update(minimum=0, maximum=0, step=1, value=0, label="Edit Strength Level"),
+         )
+
+     sample = get_sample_by_name(sample_name)
+     if not sample:
+         return (
+             None,
+             "Sample not found",
+             [],
+             None,
+             gr.update(minimum=0, maximum=0, step=1, value=0, label="Edit Strength Level"),
+         )
+
+     # Load the sample image
+     sample_image = load_sample_image(sample["image_path"])
+     prompt = sample["prompt"]
+
+     # Load the precomputed results
+     precomputed_images = load_precomputed_results(sample["precomputed_results"])
+     first_result = precomputed_images[0] if precomputed_images else None
+
+     # Update the slider range for the precomputed results
+     n_results = len(precomputed_images)
+     slider_update = gr.update(
+         minimum=0,
+         maximum=max(0, n_results - 1),
+         step=1,
+         value=0,
+         label=f"Edit Strength Level (0-{n_results - 1}) - Precomputed",
+     )
+
+     return sample_image, prompt, precomputed_images, first_result, slider_update
+
+
+ # -----------------------------
+ # UI helpers
+ # -----------------------------
+ def update_slider_range(n_edits):
+     """Update the slider range based on the number of edits."""
+     return gr.update(
+         minimum=0,
+         maximum=max(0, int(n_edits) - 1),
+         step=1,
+         value=0,
+         label=f"Edit Strength Level (0-{int(n_edits) - 1})",
+     )
+
+
+ def display_selected_image(slider_index: int, images_list: List[Image.Image]) -> Image.Image:
+     """
+     Return the image at the slider index from the generated images list.
+
+     Args:
+         slider_index: Current slider position (0-based index)
+         images_list: List of generated/precomputed images
+
+     Returns:
+         The selected image, or None for an empty list.
+     """
+     if not images_list:
+         return None
+
+     # Clamp the index to the valid range
+     idx = max(0, min(int(slider_index), len(images_list) - 1))
+     return images_list[idx]
+
+ # -----------------------------
+ # Gradio UI
+ # -----------------------------
+ # Helper for user uploads (not currently wired into the UI below)
+ def process_user_upload(uploaded_image, user_prompt, n_edits_val):
+     """Handle user-uploaded images and custom prompts."""
+     if uploaded_image is None:
+         return None, [], None, gr.update(minimum=0, maximum=0, step=1, value=0, label="Edit Strength Level")
+
+     # Resize the uploaded image
+     processed_image = resize_image(uploaded_image, 512)
+
+     # Generate the edits
+     generated_list, first_result = generate_image_stack_edits(user_prompt, n_edits_val, processed_image)
+
+     # Update the slider range (0-based indices into the generated list)
+     slider_update = gr.update(
+         minimum=0,
+         maximum=max(0, len(generated_list) - 1),
+         step=1,
+         value=0,
+         label=f"Edit Strength Level (0-{len(generated_list) - 1})",
+     )
+
+     return processed_image, generated_list, first_result, slider_update
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Kontinuous Kontext - Continuous Strength Control for Instruction-based Image Editing")
+
+     # Description section
+     gr.Markdown("""
+     ## About
+     ### Kontinuous Kontext lets you edit a given image with a freeform text instruction and a slider strength value.
+     ### The slider strength provides precise control over the extent of the applied edit and yields smooth transitions between editing levels.
+
+     ### You can either:
+     1. Choose from our sample images with predefined edit instructions
+     2. Upload your own image and specify custom editing instructions
+
+     Check out the [paper](https://arxiv.org/pdf/2510.08532v1) and the [project page](https://snap-research.github.io/kontinuouskontext) for more details.
+     """)
+
+     # Custom CSS for the tabs and image panes
+     gr.Markdown("""
+     <style>
+     .tabs.svelte-710i53 {
+         margin-top: 2em !important;
+         margin-bottom: 2em !important;
+     }
+     .tabs.svelte-710i53 button {
+         font-size: 1.2em !important;
+         padding: 0.5em 2em !important;
+         min-width: 200px !important;
+     }
+     #sample_image, #sample_output, #upload_image, #upload_output {
+         min-height: 512px !important;
+         max-height: 512px !important;
+     }
+     </style>
+     """)
525
+
526
+ with gr.Tabs() as tabs:
527
+ # Common style parameters for images
528
+ IMAGE_WIDTH = 512
529
+ IMAGE_HEIGHT = 512
530
+
531
+ with gr.TabItem("📸 Examples") as tab1: # Added emoji and changed tab name
532
+ with gr.Row(equal_height=True):
533
+ with gr.Column(scale=1):
534
+ sample_dropdown = gr.Dropdown(
535
+ choices=[sample["name"] for sample in SAMPLE_DATA],
536
+ label="Select Sample Image & Prompt",
537
+ value=None
538
+ )
539
+ sample_text = gr.Textbox(lines=1, show_label=False, placeholder="Please select a sample above", interactive=False)
540
+ sample_n_edits = gr.Number(value=5, minimum=1, maximum=20, step=1, label="Number of Edits", precision=0)
541
+ sample_image = gr.Image(
542
+ type="pil",
543
+ label="Source Image",
544
+ width=IMAGE_WIDTH,
545
+ height=IMAGE_HEIGHT,
546
+ interactive=False,
547
+ elem_id="sample_image"
548
+ )
549
+ sample_button = gr.Button("Display Edits") # Added back
550
+ with gr.Column(scale=1):
551
+ with gr.Row():
552
+ sample_slider = gr.Slider(
553
+ minimum=0,
554
+ maximum=1,
555
+ step=0.1,
556
+ value=0,
557
+ label="Edit Strength",
558
+ scale=1,
559
+ min_width=100
560
+ )
561
+ sample_output = gr.Image(
562
+ type="pil",
563
+ label="Edited Output",
564
+ width=IMAGE_WIDTH,
565
+ height=IMAGE_HEIGHT,
566
+ elem_id="sample_output"
567
+ )
568
+
569
+ with gr.TabItem("⬆️ Upload Your Image") as tab2: # Added emoji and changed tab name
570
+ with gr.Row(equal_height=True):
571
+ with gr.Column(scale=1):
572
+ upload_text = gr.Textbox(lines=1, label="Enter Editing Prompt", placeholder="Describe the edit you want...")
573
+ upload_n_edits = gr.Number(value=5, minimum=1, maximum=20, step=1, label="Number of Edits", precision=0)
574
+ upload_image = gr.Image(
575
+ type="pil",
576
+ label="Upload Image",
577
+ width=IMAGE_WIDTH,
578
+ height=IMAGE_HEIGHT,
579
+ elem_id="upload_image"
580
+ )
581
+ upload_button = gr.Button("Generate Edits") # Kept consistent with sample tab
582
+ with gr.Column(scale=1):
583
+ with gr.Row():
584
+ upload_slider = gr.Slider(
585
+ minimum=0,
586
+ maximum=1,
587
+ step=0.1,
588
+ value=0,
589
+ label="Edit Strength Level",
590
+ scale=1,
591
+ min_width=100
592
+ )
593
+ upload_output = gr.Image(
594
+ type="pil",
595
+ label="Edited Output",
596
+ width=IMAGE_WIDTH,
597
+ height=IMAGE_HEIGHT,
598
+ elem_id="upload_output"
599
+ )
600
+
601
+ # States for both tabs
602
+ sample_generated_images = gr.State([])
603
+ upload_generated_images = gr.State([])
604
+
605
+ # Sample tab logic
606
+ sample_dropdown.change(
607
+ load_sample_to_main_interface,
608
+ inputs=[sample_dropdown],
609
+ outputs=[sample_image, sample_text, sample_generated_images, sample_output, sample_slider]
610
+ )
611
+
612
+ sample_button.click(
613
+ generate_image_stack_edits,
614
+ inputs=[sample_text, sample_n_edits, sample_image],
615
+ outputs=[sample_generated_images, sample_output],
616
+ ).then(
617
+ update_slider_range,
618
+ inputs=[sample_n_edits],
619
+ outputs=[sample_slider],
620
+ )
621
+
622
+ sample_slider.change(
623
+ display_selected_image,
624
+ inputs=[sample_slider, sample_generated_images],
625
+ outputs=[sample_output],
626
+ )
627
+
628
+ # Upload tab logic - Remove duplicate click handler and combine the logic
629
+ upload_button.click(
630
+ generate_image_stack_edits, # Generate images first
631
+ inputs=[upload_text, upload_n_edits, upload_image],
632
+ outputs=[upload_generated_images, upload_output],
633
+ ).then(
634
+ update_slider_range, # Then update slider range
635
+ inputs=[upload_n_edits],
636
+ outputs=[upload_slider],
637
+ )
638
+
639
+ # Update slider when n_edits changes
640
+ upload_n_edits.change(
641
+ update_slider_range,
642
+ inputs=[upload_n_edits],
643
+ outputs=[upload_slider],
644
+ )
645
+
646
+ upload_slider.change(
647
+ display_selected_image,
648
+ inputs=[upload_slider, upload_generated_images],
649
+ outputs=[upload_output],
650
+ )
651
+
+     # Citation section at the bottom
+     gr.Markdown("""
+     ---
+     ### If you find this work useful, please cite:
+     ```bibtex
+     @article{kontinuous_kontext_2025,
+         title={Kontinuous Kontext: Continuous Strength Control for Instruction-based Image Editing},
+         author={R. Parihar and O. Patashnik and D. Ostashev and R. Venkatesh Babu and D. Cohen-Or and J. Wang},
+         journal={arXiv preprint arXiv:2510.08532},
+         year={2025}
+     }
+     ```
+     """)
+
+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)
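
For headless use, the same pipeline can be driven without the Gradio UI by reusing the app's own helpers. A minimal sketch, assuming `app.py` is importable from this repo and `model_weights/` is in place (the sample image path below is illustrative):

```python
# Headless sketch: produce one edit at slider strength 0.5 by reusing
# generate_single_image from app.py (importing app loads PIPELINE once).
from PIL import Image

import app  # builds the pipeline and moves it to the available device at import time

source = Image.open("sample_images/jackson_fluffy.png")  # illustrative input path
edited = app.generate_single_image(
    "Transform his jacket into a blue fluffy fur jacket",  # editing instruction
    0.5,                                                   # slider strength in (0, 1]
    source,
)
edited.save("edited_strength_0.5.png")
```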
requirements.txt ADDED
@@ -0,0 +1,78 @@
+ --extra-index-url https://download.pytorch.org/whl/cu126
+ absl-py==2.3.1
+ accelerate==1.9.0
+ annotated-types==0.7.0
+ av==15.1.0
+ bitsandbytes==0.46.1
+ certifi==2025.7.14
+ charset-normalizer==3.4.2
+ click==8.2.1
+ -e git+https://github.com/huggingface/diffusers@05e7a854d0a5661f5b433f6dd5954c224b104f0b#egg=diffusers
+ filelock==3.18.0
+ fsspec==2025.7.0
+ ftfy==6.3.1
+ gitdb==4.0.12
+ GitPython==3.1.45
+ grpcio==1.74.0
+ hf-xet==1.1.5
+ huggingface-hub==0.34.3
+ idna==3.10
+ imageio==2.37.0
+ importlib_metadata==8.7.0
+ Jinja2==3.1.6
+ lpips==0.1.4
+ Markdown==3.8.2
+ MarkupSafe==3.0.2
+ mpmath==1.3.0
+ networkx==3.4.2
+ numpy==2.2.6
+ nvidia-cublas-cu12==12.6.4.1
+ nvidia-cuda-cupti-cu12==12.6.80
+ nvidia-cuda-nvrtc-cu12==12.6.77
+ nvidia-cuda-runtime-cu12==12.6.77
+ nvidia-cudnn-cu12==9.5.1.17
+ nvidia-cufft-cu12==11.3.0.4
+ nvidia-cufile-cu12==1.11.1.6
+ nvidia-curand-cu12==10.3.7.77
+ nvidia-cusolver-cu12==11.7.1.2
+ nvidia-cusparse-cu12==12.5.4.2
+ nvidia-cusparselt-cu12==0.6.3
+ nvidia-ml-py==13.580.65
+ nvidia-nccl-cu12==2.26.2
+ nvidia-nvjitlink-cu12==12.6.85
+ nvidia-nvtx-cu12==12.6.77
+ nvitop==1.5.3
+ opencv-python==4.12.0.88
+ packaging==25.0
+ peft==0.16.0
+ pillow==11.3.0
+ platformdirs==4.3.8
+ protobuf==6.31.1
+ psutil==7.0.0
+ pydantic==2.11.7
+ pydantic_core==2.33.2
+ PyYAML==6.0.2
+ regex==2024.11.6
+ requests==2.32.4
+ safetensors==0.5.3
+ scipy==1.15.3
+ sentencepiece==0.2.0
+ sentry-sdk==2.34.0
+ smmap==5.0.2
+ sympy==1.14.0
+ tensorboard==2.20.0
+ tensorboard-data-server==0.7.2
+ tokenizers==0.21.4
+ torch==2.7.1
+ torchvision==0.22.1
+ tqdm==4.67.1
+ transformers==4.54.1
+ triton==3.3.1
+ typing-inspection==0.4.1
+ typing_extensions==4.14.1
+ urllib3==2.5.0
+ wandb==0.21.0
+ wcwidth==0.2.13
+ Werkzeug==3.1.3
+ zipp==3.23.0
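+ # note: gradio itself is not pinned here; on Hugging Face Spaces it is supplied by the Space's SDK setting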