tomo2chin2 commited on
Commit
7c4c193
·
verified ·
1 Parent(s): 13dfefe

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +858 -659
app.py CHANGED
@@ -12,8 +12,8 @@ import insightface
12
  import onnxruntime
13
  import numpy as np
14
  import gradio as gr
15
- # import threading # Not explicitly used, can be removed if StreamerThread is not used
16
- # import queue # Not explicitly used, can be removed if StreamerThread is not used
17
 
18
  from datasets import Dataset, Features, Image as DatasetImage, Value, load_dataset, concatenate_datasets
19
  from PIL import Image
@@ -25,7 +25,7 @@ import concurrent.futures
25
  from moviepy.editor import VideoFileClip
26
 
27
  from face_swapper import Inswapper, paste_to_whole
28
- from face_analyser import detect_conditions, get_analysed_data, swap_options_list as original_swap_options_list
29
  from face_parsing import init_parsing_model, get_parsed_mask, mask_regions, mask_regions_to_list
30
  from face_enhancer import get_available_enhancer_names, load_face_enhancer_model, cv2_interpolations
31
  from utils import trim_video, StreamerThread, ProcessBar, open_directory, split_list_by_lengths, merge_img_sequence_from_ref, create_image_grid
@@ -35,7 +35,7 @@ from utils import trim_video, StreamerThread, ProcessBar, open_directory, split_
35
  parser = argparse.ArgumentParser(description="Free Face Swapper")
36
  parser.add_argument("--out_dir", help="Default Output directory", default=os.getcwd())
37
  parser.add_argument("--batch_size", help="Gpu batch size", default=32)
38
- parser.add_argument("--cuda", action="store_true", help="Enable cuda", default=True) # Default changed to True based on original
39
  parser.add_argument(
40
  "--colab", action="store_true", help="Enable colab mode", default=False
41
  )
@@ -49,12 +49,12 @@ DEF_OUTPUT_PATH = user_args.out_dir
49
  BATCH_SIZE = int(user_args.batch_size)
50
  WORKSPACE = None
51
  OUTPUT_FILE = None
52
- CURRENT_FRAME = None # Seems unused
53
- STREAMER = None # Related to Stream input type, which is hidden
54
  DETECT_CONDITION = "best detection"
55
  DETECT_SIZE = 640
56
  DETECT_THRESH = 0.7
57
- NUM_OF_SRC_SPECIFIC = 10 # For hidden specific face UI
58
  MASK_INCLUDE = [
59
  "Skin",
60
  "R-Eyebrow",
@@ -73,17 +73,15 @@ MASK_ERODE_AMOUNT = 0.05
73
 
74
  FACE_SWAPPER = None
75
  FACE_ANALYSER = None
76
- # FACE_ENHANCER = "GFPGAN" # This is a default string, the model object is FACE_ENHANCER_MODEL
77
- FACE_ENHANCER_MODEL = None # To store the loaded enhancer model object
78
  FACE_PARSER = None
79
- FACE_ENHANCER_LIST = ["None"] # "None" as a string option
80
  FACE_ENHANCER_LIST.extend(get_available_enhancer_names())
81
  FACE_ENHANCER_LIST.extend(cv2_interpolations)
82
 
83
- swap_options_list_ui = [opt for opt in original_swap_options_list if opt != "Specific Face"]
84
-
85
-
86
  ## ------------------------------ SET EXECUTION PROVIDER ------------------------------
 
 
87
  PROVIDER = ["CPUExecutionProvider"]
88
  if USE_CUDA:
89
  available_providers = onnxruntime.get_available_providers()
@@ -91,781 +89,982 @@ if USE_CUDA:
91
  if "CUDAExecutionProvider" in available_providers:
92
  print("\n********** Running on CUDA **********\n")
93
  PROVIDER = ["CUDAExecutionProvider", "CPUExecutionProvider"]
 
94
  else:
95
- USE_CUDA = False # Correctly set USE_CUDA to False if provider not found
96
  print("\n********** CUDA unavailable running on CPU **********\n")
97
  else:
98
- # USE_CUDA = False # Already false or set by arg
99
  print("\n********** Running on CPU **********\n")
100
 
101
  device = "cuda" if USE_CUDA else "cpu"
102
- EMPTY_CACHE = lambda: torch.cuda.empty_cache() if device == "cuda" and torch.cuda.is_available() else None # Added torch.cuda.is_available() check
103
-
104
- # print(onnxruntime.get_available_providers())
105
- # print(f"Torch CUDA available: {torch.cuda.is_available()}")
106
- # if torch.cuda.is_available():
107
- # print(f"Torch CUDA device count: {torch.cuda.device_count()}")
108
- # print(f"Torch current CUDA device: {torch.cuda.current_device()}")
109
- # if torch.cuda.device_count() > 0:
110
- # print(f"Torch CUDA device name: {torch.cuda.get_device_name(0)}")
111
 
112
  ## ------------------------------ LOAD MODELS ------------------------------
113
 
114
  def load_face_analyser_model(name="buffalo_l"):
115
  global FACE_ANALYSER
116
  if FACE_ANALYSER is None:
117
- print("Loading face analyser model...")
118
  FACE_ANALYSER = insightface.app.FaceAnalysis(name=name, providers=PROVIDER)
119
  FACE_ANALYSER.prepare(
120
  ctx_id=0, det_size=(DETECT_SIZE, DETECT_SIZE), det_thresh=DETECT_THRESH
121
  )
122
- print("Face analyser model loaded.")
123
 
124
 
125
- def load_face_swapper_model(model_path="./assets/pretrained_models/inswapper_128.onnx"): # Renamed arg for clarity
126
  global FACE_SWAPPER
127
  if FACE_SWAPPER is None:
128
- print(f"Loading face swapper model from {model_path}...")
129
  batch = int(BATCH_SIZE) if device == "cuda" else 1
130
- FACE_SWAPPER = Inswapper(model_file=model_path, batch_size=batch, providers=PROVIDER)
131
- print("Face swapper model loaded.")
132
 
133
 
134
- def load_face_parser_model(model_path="./assets/pretrained_models/79999_iter.pth"): # Renamed arg for clarity
135
  global FACE_PARSER
136
  if FACE_PARSER is None:
137
- print(f"Loading face parsing model from {model_path}...")
138
- FACE_PARSER = init_parsing_model(model_path, device=device)
139
- print("Face parsing model loaded.")
140
 
141
- # Pre-load models at startup
142
  load_face_analyser_model()
143
  load_face_swapper_model()
144
- # Face parser and enhancer are loaded on demand by process function
145
 
146
  ## ------------------------------ MAIN PROCESS ------------------------------
 
 
147
  def process(
148
- # input_type, # REMOVED - hardcoded to "Image"
149
  image_path,
150
- video_path, # Will be None as UI is hidden
151
- directory_path, # Will be None as UI is hidden
152
  source_path,
153
  output_path,
154
  output_name,
155
- keep_output_sequence, # From hidden UI, relevant for video
156
- condition, # Swap condition from UI
157
- age, # From UI, visibility controlled
158
- distance, # From hidden UI (specific face)
159
- face_enhancer_name, # From UI dropdown
160
- enable_face_parser, # From UI checkbox
161
- mask_includes, # From UI dropdown
162
- mask_soft_kernel_ui, # Renamed to avoid clash with global, from UI (hidden)
163
- mask_soft_iterations_ui, # Renamed, from UI
164
- blur_amount, # From UI slider
165
- erode_amount, # From UI slider
166
- face_scale, # From UI slider
167
- enable_laplacian_blend, # From UI checkbox
168
- crop_top, # From UI slider
169
- crop_bott, # From UI slider
170
- crop_left, # From UI slider
171
- crop_right, # From UI slider
172
- *specifics_components, # Tuple of Gradio components for specific faces (hidden UI)
173
  ):
174
- global WORKSPACE, OUTPUT_FILE, PREVIEW, FACE_ENHANCER_MODEL, FACE_PARSER
175
-
 
176
  WORKSPACE, OUTPUT_FILE, PREVIEW = None, None, None
177
- input_type = "Image" # Hardcoded
178
 
179
- # Use UI values for mask kernel and iterations
180
- current_mask_soft_kernel = int(mask_soft_kernel_ui)
181
- current_mask_soft_iterations = int(mask_soft_iterations_ui)
182
 
183
-
184
- def ui_before(): # Updates for UI elements during processing
185
  return (
186
- gr.update(visible=True, value=PREVIEW), # preview_image (Output image)
187
- gr.update(interactive=False), # Corresponds to output_directory_button in original swap_outputs
188
- gr.update(interactive=False), # Corresponds to output_video_button in original swap_outputs
189
- gr.update(visible=False), # preview_video (Output video)
190
  )
191
 
192
- def ui_after(): # Updates for UI elements after successful image processing
193
  return (
194
- gr.update(visible=True, value=PREVIEW), # preview_image
195
- gr.update(interactive=True), # output_directory_button
196
- gr.update(interactive=True), # output_video_button (though for image, this might stay hidden)
197
- gr.update(visible=False), # preview_video
198
  )
199
 
200
- def ui_after_vid(): # Updates for UI elements after successful video processing (currently unused)
201
  return (
202
- gr.update(visible=False), # preview_image
203
- gr.update(interactive=True), # output_directory_button
204
- gr.update(interactive=True), # output_video_button
205
- gr.update(value=OUTPUT_FILE, visible=True), # preview_video
206
  )
207
 
208
  start_time = time.time()
209
- total_exec_time = lambda st: divmod(time.time() - st, 60)
210
- get_finsh_text = lambda st: f"✔️ Completed in {int(total_exec_time(st)[0])} min {int(total_exec_time(st)[1])} sec."
 
 
211
 
212
- try: # Wrap main processing in try-except
213
- yield "### \n 🌀 Ensuring face analyser model is loaded...", *ui_before()
214
- load_face_analyser_model() # Ensures it's loaded, doesn't reload if already there
 
215
 
216
- yield "### \n ⚙️ Ensuring face swapper model is loaded...", *ui_before()
217
- load_face_swapper_model() # Ensures it's loaded
218
 
219
- global FACE_ENHANCER_MODEL # Ensure we are using the global for the loaded model
220
- if face_enhancer_name != "None": # String "None" from dropdown
221
- if face_enhancer_name not in cv2_interpolations:
222
- yield f"### \n 💡 Loading {face_enhancer_name} model...", *ui_before()
223
- FACE_ENHANCER_MODEL = load_face_enhancer_model(name=face_enhancer_name, device=device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  else:
225
- FACE_ENHANCER_MODEL = None
 
 
 
 
 
 
 
 
226
 
227
- if enable_face_parser:
228
- yield "### \n 📀 Loading face parsing model...", *ui_before()
229
- load_face_parser_model() # Ensures it's loaded
230
-
231
- includes = mask_regions_to_list(mask_includes)
232
-
233
- # Specifics components are passed, but their values will be None as UI is hidden
234
- # The logic for 'Specific Face' condition might not be fully reachable if it relies on UI input for these.
235
- # For now, we assume `condition` will not be "Specific Face" or handles None for sources/specifics.
236
- # If `specifics_components` are needed, their .value would be accessed.
237
- # Since they are hidden, this part of logic for "Specific Face" may need review if that condition is used.
238
- # For now, `sources` and `specific_face_targets_from_ui` will be empty if condition != "Specific Face".
239
- sources_from_ui = []
240
- specific_face_targets_from_ui = []
241
- # If specific_components were populated (e.g. if UI was visible):
242
- # half_len = len(specifics_components) // 2
243
- # sources_from_ui = [comp.value for comp in specifics_components[:half_len] if comp.value is not None]
244
- # specific_face_targets_from_ui = [comp.value for comp in specifics_components[half_len:] if comp.value is not None]
245
-
246
-
247
- if crop_top > crop_bott: crop_top, crop_bott = crop_bott, crop_top
248
- if crop_left > crop_right: crop_left, crop_right = crop_right, crop_left
249
- crop_mask_dims = (crop_top, 511-crop_bott, crop_left, 511-crop_right) # Renamed for clarity
250
-
251
- # Inner function for the core swapping logic (on a sequence of image paths)
252
- def swap_process_on_sequence(image_path_sequence):
253
- nonlocal PREVIEW # Allow modification of PREVIEW
254
- yield "### \n 🧿 Analysing face data...", *ui_before()
255
-
256
- current_source_data = None
257
- if condition == "Specific Face":
258
- # This branch is problematic if UI for specifics is hidden,
259
- # as sources_from_ui and specific_face_targets_from_ui will be empty or contain Nones.
260
- # The get_analysed_data must handle this.
261
- # For now, assuming `distance` (from hidden slider) is the primary input.
262
- # This path needs robust handling if "Specific Face" is ever re-enabled or used programmatically.
263
- print("Warning: 'Specific Face' condition selected, but UI for specific faces is hidden.")
264
- # `source_path` (the single source image) might be used as a fallback or primary source here.
265
- # This part of the logic is unclear without knowing how `get_analysed_data` uses `source_data` for "Specific Face"
266
- # when `specifics_components` are effectively None.
267
- # Assuming it might use `source_path` if other specifics are missing.
268
- # If `sources_from_ui` and `specific_face_targets_from_ui` are empty, this will likely fail or misbehave.
269
- # For safety, if "Specific Face" is chosen and specifics are empty, one might default to another behavior or error.
270
- if not source_path: # If even the main source_path is missing
271
- yield "### \n ❌ 'Specific Face' requires at least one source face image.", *ui_after()
272
- raise ValueError("'Specific Face' requires at least one source face image.")
273
- # Simplified: if specific_face_targets_from_ui is empty, it implies swapping all faces in target with source_path
274
- # This is a guess; original logic for `source_data = ((sources, specifics), distance)` needs `sources` and `specifics`
275
- # For now, we'll pass what we have.
276
- # `sources` would be a list of numpy arrays for source faces from specific_face UI.
277
- # `specifics` would be a list of numpy arrays for target faces from specific_face UI.
278
- # Since these UI elements are hidden, they will be None or empty.
279
- # The original `specifics` variable in `process` was from `*specifics_components`.
280
- # We should use `sources_from_ui` and `specific_face_targets_from_ui` here.
281
- source_data = ((sources_from_ui, specific_face_targets_from_ui), distance)
282
-
283
- else: # For other conditions like "Age", "Gender", etc.
284
- if not source_path:
285
- yield "### \n ❌ Source face image is required for this swap condition.", *ui_after()
286
- raise ValueError("Source face image is required for this swap condition.")
287
- source_data = source_path, age # `age` from UI
288
-
289
- analysed_targets, analysed_sources, whole_frame_list, num_faces_per_frame = get_analysed_data(
290
- FACE_ANALYSER,
291
- image_path_sequence, # List of image file paths
292
- source_data,
293
- swap_condition=condition,
294
- detect_condition=DETECT_CONDITION,
295
- scale=face_scale
296
- )
297
 
298
- if not analysed_targets: # No faces detected in target, or other issue
299
- yield "### \n ⚠️ No target faces found or error in analysis. Cannot proceed with swap.", *ui_after()
300
- return # Stop processing for this sequence
 
 
 
 
 
 
 
301
 
 
 
 
 
302
 
303
- yield "### \n 🧶 Generating faces...", *ui_before()
304
- preds, matrs = [], []
305
- batch_count = 0
306
- # Ensure whole_frame_list, analysed_targets, analysed_sources are not empty/None before batch_forward
307
- if not whole_frame_list or not analysed_targets or analysed_sources is None: # analysed_sources can be tricky (single vs multiple)
308
- yield "### \n ⚠️ Missing data for face generation. Cannot proceed.", *ui_after()
309
- return
310
 
 
 
 
 
 
 
 
 
311
 
312
- for batch_pred, batch_matr in FACE_SWAPPER.batch_forward(whole_frame_list, analysed_targets, analysed_sources):
313
- preds.extend(batch_pred)
314
- matrs.extend(batch_matr)
315
- EMPTY_CACHE()
316
- batch_count += 1
317
- if USE_CUDA and batch_pred: # Check if batch_pred is not empty
318
- image_grid = create_image_grid(batch_pred, size=128) # Ensure batch_pred is list of images
319
- PREVIEW = image_grid[:, :, ::-1]
320
- yield f"### \n 🧩 Generating face Batch {batch_count}", *ui_before()
321
-
322
- generated_len = len(preds)
323
- if generated_len == 0:
324
- yield "### \n ⚠️ No faces were generated. Check source and target images/faces.", *ui_after()
325
- return # Nothing to enhance or parse or paste
326
-
327
- if FACE_ENHANCER_MODEL is not None and face_enhancer_name != "None": # Check model object
328
- yield f"### \n 🎲 Upscaling faces with {face_enhancer_name}...", *ui_before()
329
- # tqdm description should be clear
330
- for idx, pred_img in tqdm(enumerate(preds), total=generated_len, desc=f"Upscaling with {face_enhancer_name}"):
331
- enhancer_model_obj, enhancer_model_runner_func = FACE_ENHANCER_MODEL # Unpack
332
- enhanced_pred = enhancer_model_runner_func(pred_img, enhancer_model_obj)
333
- preds[idx] = cv2.resize(enhanced_pred, (512,512))
334
- EMPTY_CACHE()
335
 
336
- parsed_masks = [None] * generated_len # Initialize with Nones
337
- if enable_face_parser and FACE_PARSER is not None:
338
- yield "### \n 🎨 Face-parsing mask...", *ui_before()
339
- temp_masks_list = [] # To collect batches of masks
340
- parse_batch_count = 0
341
- for batch_mask_data in get_parsed_mask(FACE_PARSER, preds, classes=includes, device=device, batch_size=BATCH_SIZE, softness=int(current_mask_soft_iterations)):
342
- temp_masks_list.append(batch_mask_data) # batch_mask_data is likely a numpy array of masks
343
- EMPTY_CACHE()
344
- parse_batch_count += 1
345
- if len(batch_mask_data) > 0: # Check if batch_mask_data is not empty
346
- # Assuming batch_mask_data is a list/array of single-channel masks
347
- # For create_image_grid, masks might need to be converted to 3-channel grayscale if they are not already
348
- displayable_masks = []
349
- for msk in batch_mask_data:
350
- if msk.ndim == 2: displayable_masks.append(cv2.cvtColor(msk, cv2.COLOR_GRAY2BGR))
351
- elif msk.ndim == 3 and msk.shape[2] == 1: displayable_masks.append(cv2.cvtColor(msk, cv2.COLOR_GRAY2BGR))
352
- else: displayable_masks.append(msk) # Assume it's already displayable
353
-
354
- if displayable_masks:
355
- image_grid = create_image_grid(displayable_masks, size=128)
356
- PREVIEW = image_grid[:, :, ::-1]
357
- yield f"### \n 🪙 Face parsing Batch {parse_batch_count}", *ui_before()
358
- if temp_masks_list: # If any masks were generated
359
- parsed_masks = np.concatenate(temp_masks_list, axis=0)
360
-
361
-
362
- split_preds = split_list_by_lengths(preds, num_faces_per_frame)
363
- del preds
364
- split_matrs = split_list_by_lengths(matrs, num_faces_per_frame)
365
- del matrs
366
- # Ensure parsed_masks has the correct structure for split_list_by_lengths
367
- # If parsed_masks is a single concatenated array, it needs to be a list of masks first if not already.
368
- # Assuming get_parsed_mask and concatenate result in a flat list/array that split_list can handle.
369
- if isinstance(parsed_masks, np.ndarray) and parsed_masks.ndim > 1 and len(parsed_masks) == generated_len : # Check if it's an array of masks
370
- parsed_masks_list = [parsed_masks[i] for i in range(generated_len)]
371
- elif isinstance(parsed_masks, list) and len(parsed_masks) == generated_len:
372
- parsed_masks_list = parsed_masks
373
- else: # Fallback if structure is unexpected or it remained all Nones
374
- parsed_masks_list = [None] * generated_len
375
-
376
- split_masks = split_list_by_lengths(parsed_masks_list, num_faces_per_frame)
377
- del parsed_masks, parsed_masks_list
378
-
379
-
380
- yield "### \n 🧿 Pasting back...", *ui_before()
381
- def post_process_frame(frame_idx, frame_img_path, current_split_preds, current_split_matrs, current_split_masks):
382
- whole_img = cv2.imread(frame_img_path)
383
- if whole_img is None:
384
- print(f"Error: Could not read frame for pasting: {frame_img_path}")
385
- return
386
-
387
- blend_method = 'laplacian' if enable_laplacian_blend else 'linear'
388
- # Ensure frame_idx is within bounds of all lists
389
- if frame_idx < len(current_split_preds) and \
390
- frame_idx < len(current_split_matrs) and \
391
- frame_idx < len(current_split_masks):
392
-
393
- for p_img, m_data, mask_img in zip(current_split_preds[frame_idx], current_split_matrs[frame_idx], current_split_masks[frame_idx]):
394
- p_resized = cv2.resize(p_img, (512,512))
395
- mask_resized = cv2.resize(mask_img, (512,512)) if mask_img is not None else None
396
- m_data_scaled = m_data / 0.25 # Transformation matrix scaling
397
- whole_img = paste_to_whole(p_resized, whole_img, m_data_scaled, mask=mask_resized,
398
- crop_mask=crop_mask_dims, blend_method=blend_method,
399
- blur_amount=blur_amount, erode_amount=erode_amount)
400
- cv2.imwrite(frame_img_path, whole_img) # Overwrite the frame in the sequence
401
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
402
  with concurrent.futures.ThreadPoolExecutor() as executor:
403
- futures = [executor.submit(post_process_frame, idx, img_path, split_preds, split_matrs, split_masks)
404
- for idx, img_path in enumerate(image_path_sequence)]
 
 
 
405
  for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Pasting back"):
406
- future.result() # Wait for completion and raise exceptions if any
 
 
 
 
 
 
 
 
 
 
 
407
 
408
- # IMAGE INPUT PROCESSING
409
- if input_type == "Image":
410
- if not image_path:
411
- yield "### \n ❌ Target image not provided.", *ui_after()
412
- return
413
 
414
- os.makedirs(output_path, exist_ok=True)
415
- # Create a working copy of the target image for processing
416
- # to avoid modifying the original if it's in a restricted location.
417
- base_name, ext = os.path.splitext(os.path.basename(image_path))
418
- processing_image_filename = f"processing_copy_{base_name}{ext}"
419
- processing_image_path = os.path.join(output_path, processing_image_filename)
420
 
421
- try:
422
- shutil.copyfile(image_path, processing_image_path)
423
- except Exception as e:
424
- yield f"### \n ❌ Failed to copy target image for processing: {str(e)}", *ui_after()
425
- return
426
-
427
- # Call swap_process_on_sequence with a list containing the single image path
428
- for info_update in swap_process_on_sequence([processing_image_path]):
429
- yield info_update
430
-
431
- # Define final output file path (e.g., result.png)
432
- final_output_file_path = os.path.join(output_path, output_name + ".png") # Assuming PNG output for images
433
- # Move the processed image (which overwrote processing_image_path) to the final path
434
- try:
435
- if os.path.exists(final_output_file_path) and final_output_file_path != processing_image_path:
436
- os.remove(final_output_file_path) # Remove if exists and is different file
437
- shutil.move(processing_image_path, final_output_file_path)
438
- except Exception as e:
439
- yield f"### \n ❌ Failed to save final image: {str(e)}", *ui_after()
440
- # Try to provide the temp file if move fails
441
- if os.path.exists(processing_image_path):
442
- final_output_file_path = processing_image_path # Fallback to the temp file
443
- else: # No file available
444
- return
445
-
446
-
447
- OUTPUT_FILE = final_output_file_path
448
- WORKSPACE = output_path
449
- # Load the final image for PREVIEW
450
- final_image_preview = cv2.imread(OUTPUT_FILE)
451
- if final_image_preview is not None:
452
- PREVIEW = final_image_preview[:, :, ::-1]
453
- else: # Fallback if reading the final output fails
454
- PREVIEW = None
455
- yield "### \n ⚠️ Could not load final image for preview.", *ui_after()
456
- # Still yield finish text
457
- yield get_finsh_text(start_time), *ui_after()
458
- return
459
-
460
- yield get_finsh_text(start_time), *ui_after()
461
-
462
- # VIDEO INPUT PROCESSING (Currently Unreachable via UI)
463
- elif input_type == "Video":
464
- if not video_path:
465
- yield "### \n ❌ Target video not provided.", *ui_after_vid() # Use ui_after_vid for consistency
466
- return
467
-
468
- temp_sequence_path = os.path.join(output_path, output_name, "sequence")
469
- os.makedirs(temp_sequence_path, exist_ok=True)
470
-
471
- yield "### \n ⌛ Extracting video frames...", *ui_before()
472
- extracted_image_paths = []
473
- cap = cv2.VideoCapture(video_path)
474
- frame_idx = 0
475
- while True:
476
- ret, frame = cap.read()
477
- if not ret: break
478
- frame_file_path = os.path.join(temp_sequence_path, f"frame_{frame_idx:06d}.jpg") # Padded frame numbers
479
- cv2.imwrite(frame_file_path, frame)
480
- extracted_image_paths.append(frame_file_path)
481
- frame_idx += 1
482
- cap.release()
483
- # cv2.destroyAllWindows() # Not needed for backend processing
484
-
485
- if not extracted_image_paths:
486
- yield "### \n ❌ Video is empty or could not extract frames.", *ui_after_vid()
487
- if os.path.exists(temp_sequence_path): shutil.rmtree(temp_sequence_path)
488
- return
489
-
490
- for info_update in swap_process_on_sequence(extracted_image_paths):
491
- yield info_update
492
-
493
- yield "### \n ⌛ Merging sequence...", *ui_before()
494
- output_video_file_path = os.path.join(output_path, output_name + ".mp4")
495
- # Ensure merge_img_sequence_from_ref handles cases where video_path might be an UploadFile object
496
- original_video_for_ref = video_path.name if hasattr(video_path, 'name') else video_path
497
-
498
- merge_img_sequence_from_ref(original_video_for_ref, extracted_image_paths, output_video_file_path)
499
-
500
- if os.path.exists(temp_sequence_path) and not keep_output_sequence:
501
- yield "### \n ⌛ Removing temporary files...", *ui_before()
502
- shutil.rmtree(temp_sequence_path)
503
-
504
- WORKSPACE = output_path
505
- OUTPUT_FILE = output_video_file_path
506
- # For video, PREVIEW is handled by ui_after_vid making preview_video visible with OUTPUT_FILE
507
-
508
- yield get_finsh_text(start_time), *ui_after_vid()
509
-
510
- # DIRECTORY INPUT PROCESSING (Currently Unreachable via UI)
511
- elif input_type == "Directory":
512
- # ... (Directory processing logic, similar structure to video) ...
513
- # Ensure it uses swap_process_on_sequence
514
- yield "### \n ⚠️ Directory processing is not fully implemented in this UI path.", *ui_after()
515
- return
516
-
517
- # STREAM INPUT PROCESSING (Currently Unreachable via UI)
518
- elif input_type == "Stream":
519
- # ... (Stream processing logic) ...
520
- yield "### \n ⚠️ Stream processing is not implemented.", *ui_after()
521
- return
522
 
523
- except Exception as e:
524
- import traceback
525
- traceback.print_exc()
526
- yield f"### \n 🔥 An error occurred: {str(e)}", *ui_after() # Use ui_after for image mode fallback
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
527
 
528
 
529
  ## ------------------------------ GRADIO FUNC ------------------------------
530
- # update_radio is not called as input_type.change is removed
 
531
  def update_radio(value):
532
- if value == "Image": return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
533
- elif value == "Video": return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
534
- elif value == "Directory": return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
535
- return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False) # Default to image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
536
 
537
  def swap_option_changed(value):
538
- age_visible = bool(value and value.startswith("Age"))
539
- # specific_face group is always hidden, source_image_input visibility depends on whether age is shown
540
- source_input_visible = True # Generally source image is needed unless it's specific face mode without main source.
541
- # Since "Specific Face" is hidden, source_image_input should generally be visible.
542
- if value == "Specific Face": # This option is removed from UI, but keep logic if used internally
543
- source_input_visible = False # For "Specific Face", the individual src images are used.
 
 
 
 
 
 
 
544
 
545
- return gr.update(visible=age_visible), gr.update(visible=False), gr.update(visible=source_input_visible)
546
 
547
- def video_changed(video_file_obj): # Input is Gradio's FileData object for gr.Video
548
  sliders_update = gr.Slider.update
 
549
  number_update = gr.Number.update
550
 
551
- if video_file_obj is None or not hasattr(video_file_obj, 'name') or video_file_obj.name is None:
552
- return sliders_update(minimum=0, maximum=0, value=0), \
553
- sliders_update(minimum=1, maximum=1, value=1), \
554
- number_update(value=1)
555
- video_path = video_file_obj.name # Get filepath from FileData
 
556
  try:
557
- if not os.path.exists(video_path):
558
- print(f"Video path from Gradio object does not exist: {video_path}")
559
- return sliders_update(minimum=0, maximum=0, value=0), \
560
- sliders_update(minimum=1, maximum=1, value=1), \
561
- number_update(value=1)
562
  clip = VideoFileClip(video_path)
563
- fps = clip.fps if clip.fps is not None else 30
564
- total_frames = clip.reader.nframes if clip.reader.nframes is not None else 0
565
- max_slider = total_frames if total_frames > 0 else 1
566
  clip.close()
567
- return sliders_update(minimum=0, maximum=max_slider, value=0, interactive=True), \
568
- sliders_update(minimum=0, maximum=max_slider, value=max_slider, interactive=True), \
569
- number_update(value=fps)
570
- except Exception as e:
571
- print(f"Error processing video for metadata: {e}")
572
- return sliders_update(value=0, minimum=0, maximum=0), \
573
- sliders_update(value=0, minimum=1, maximum=1), \
574
- number_update(value=1)
575
-
576
- def analyse_settings_changed(det_cond, det_size, det_thresh): # Args renamed for clarity
577
- yield "### \n ⌛ Applying new detection values..."
578
- global FACE_ANALYSER, DETECT_CONDITION, DETECT_SIZE, DETECT_THRESH # Ensure globals are updated
579
- DETECT_CONDITION = det_cond
580
- DETECT_SIZE = int(det_size)
581
- DETECT_THRESH = float(det_thresh)
582
- # Force reload of analyser with new settings
583
- FACE_ANALYSER = None # Set to None to force re-initialization
584
- load_face_analyser_model()
585
- yield f"### \n ✔️ Applied: Cond:{det_cond}, Size:{det_size}, Thresh:{det_thresh}"
586
 
587
- def stop_running():
588
- global STREAMER # Streamer is for hidden stream type
589
- if hasattr(STREAMER, "stop"): STREAMER.stop()
590
- STREAMER = None
591
- # Optionally, could try to interrupt the current 'process' if it's in a separate thread.
592
- # For Gradio's default handling, returning "Cancelled" and having the event in `cancels` list is key.
593
- return "Processing cancelled by user."
594
-
595
- def slider_changed(show_preview_frame, video_file_obj, frame_idx_val):
596
- if not show_preview_frame or video_file_obj is None or not hasattr(video_file_obj, 'name') or video_file_obj.name is None:
597
- return gr.Image.update(value=None, visible=show_preview_frame if show_preview_frame else False), gr.Video.update(visible=not show_preview_frame if show_preview_frame is not None else True)
598
-
599
- video_path = video_file_obj.name
600
- try:
601
- if not os.path.exists(video_path): return gr.Image.update(value=None), gr.Video.update()
602
- clip = VideoFileClip(video_path)
603
- time_sec = frame_idx_val / clip.fps if clip.fps and clip.fps > 0 else 0
604
- if time_sec > clip.duration: time_sec = clip.duration
605
- frame_arr = clip.get_frame(time_sec)
606
- clip.close()
607
- return gr.Image.update(value=frame_arr, visible=True), gr.Video.update(visible=False)
608
- except Exception as e:
609
- print(f"Error in slider_changed: {e}")
610
- return gr.Image.update(value=None, visible=True), gr.Video.update(visible=False)
611
 
612
- def trim_and_reload(video_file_obj, out_dir, out_name_base, start_f, stop_f): # Args renamed
613
- if video_file_obj is None or not hasattr(video_file_obj, 'name') or video_file_obj.name is None:
614
- return None, "### \n 🔥 Video not provided for trimming."
615
-
616
- original_video_path = video_file_obj.name
617
- os.makedirs(out_dir, exist_ok=True)
618
- # Use a unique name for the trimmed video to avoid overwriting
619
- timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
620
- trimmed_video_filename = f"{os.path.splitext(out_name_base)[0]}_trimmed_{timestamp}.mp4"
621
- trimmed_video_full_path = os.path.join(out_dir, trimmed_video_filename)
622
-
623
- yield original_video_path, f"### \n 🌈 Trimming video frame {start_f} to {stop_f}..."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
624
  try:
625
- if not os.path.exists(original_video_path):
626
- raise ValueError(f"Original video for trimming not found: {original_video_path}")
627
-
628
- # trim_video should return the path to the new trimmed video
629
- new_trimmed_path = trim_video(original_video_path, trimmed_video_full_path, start_f, stop_f)
630
- yield new_trimmed_path, "### \n ✔️ Video trimmed and reloaded." # Return path for gr.Video
631
  except Exception as e:
632
- print(f"Video trimming error: {e}")
633
- yield original_video_path, f"### \n 🔥 Video trimming failed: {str(e)}. See console."
634
 
635
 
636
- def load_latest_image_from_repo(repo_choice_val):
637
- dataset_repo1_env = os.environ.get("DATASET_REPO") # Env var for repo1
638
- dataset_repo2_env = os.environ.get("DATASET_REPO2") # Env var for repo2
 
639
  hf_token = os.environ.get("HF_TOKEN")
640
 
641
- if not hf_token: return None, "❌ 環境変数 HF_TOKEN が設定されていません"
 
642
 
643
- target_repo_id = None
644
- if repo_choice_val == "repo1": target_repo_id = dataset_repo1_env
645
- elif repo_choice_val == "repo2": target_repo_id = dataset_repo2_env
646
-
647
- if not target_repo_id:
648
- return None, f"❌ 選択されたリポジトリ ({repo_choice_val}) の環境変数 (DATASET_REPO/DATASET_REPO2) 設定です"
649
 
650
  try:
651
- last_update_url = f"https://huggingface.co/datasets/{target_repo_id}/resolve/main/images/last_update.txt"
 
652
  headers = {'Authorization': f'Bearer {hf_token}'}
 
 
653
  print(f"Fetching last_update.txt from {last_update_url}")
654
- response_txt = requests.get(last_update_url, headers=headers, timeout=10)
655
- response_txt.raise_for_status()
656
- image_file_url = response_txt.text.strip() # This should be the full URL from the text file
657
-
658
- if not image_file_url: return None, "❌ last_update.txt が空か無効なURLです"
659
- print(f"Image URL from last_update.txt: {image_file_url}")
660
-
661
- print(f"Fetching image from {image_file_url}")
662
- response_img = requests.get(image_file_url, headers=headers, timeout=20) # Longer timeout for image
663
- response_img.raise_for_status()
664
- pil_image = Image.open(BytesIO(response_img.content))
665
- print("✔️ 画像の取得に成功しました")
666
- return pil_image, "✔️ 最新の画像をロードしました"
667
- except requests.exceptions.RequestException as e:
668
- print(f"RequestException: {str(e)}")
669
- return None, f"❌ ネットワーク/HTTPエラー: {str(e)}"
 
 
 
 
670
  except Exception as e:
671
- print(f"Exception: {str(e)}")
672
- return None, f"❌ 一般エラー: {str(e)}"
 
 
 
 
 
 
 
673
 
674
  ## ------------------------------ GRADIO GUI ------------------------------
 
675
  css = """
676
- .gradio-container { width: 100%; margin: 0 auto !important; padding: 20px !important; max-width: 100% !important; }
 
 
 
 
 
 
 
 
677
  """
678
 
679
  with gr.Blocks(css=css) as interface:
 
680
  with gr.Row():
681
- with gr.Column(scale=0.5, min_width=100): # Left column
682
- with gr.Group(): # Input Area Group
683
- with gr.Group(visible=True) as input_image_group: # Image input always visible
684
- repo_choice_radio = gr.Radio(["repo1", "repo2"], label="リポジトリを選択", value="repo1")
685
- target_load_button = gr.Button("TARGET_LOAD", variant="primary")
686
- target_load_info = gr.Markdown(value="...", visible=True)
687
- image_input = gr.Image(label="Target Image", interactive=True, type="filepath")
688
- # Hidden input_type radio, defaults to "Image"
689
- input_type_radio_hidden = gr.Radio(["Image", "Video"], label="Target Type", value="Image", visible=False)
690
-
691
- with gr.Group(visible=False) as input_video_group: # Video input hidden
692
- video_input = gr.Video(label="Target Video", interactive=True) # Value will be FileData obj
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
693
  with gr.Accordion("Trim video", open=False):
694
- set_slider_range_btn = gr.Button("Set frame range")
695
- show_trim_preview_btn = gr.Checkbox(label="Show frame on slider change", value=True)
696
- video_fps_num = gr.Number(value=30, interactive=False, label="Fps", visible=False)
697
- start_frame_slider = gr.Slider(minimum=0, maximum=1, value=0, step=1, label="Start Frame")
698
- end_frame_slider = gr.Slider(minimum=0, maximum=1, value=1, step=1, label="End Frame")
699
- trim_and_reload_btn = gr.Button("Trim and Reload")
700
-
701
- with gr.Group(visible=False) as input_directory_group: # Directory input hidden
702
- direc_input_text = gr.Text(label="Path", interactive=True)
703
-
704
- source_image_input = gr.Image(label="Source face", type="filepath", interactive=True)
705
-
706
- with gr.Group(visible=False) as specific_face_group: # Specific face selection hidden
707
- # Dynamically create TabItems for specific faces (these will be hidden)
708
- # Need to store these components if they were to be used.
709
- # For now, they are defined but not collected into a list for `swap_inputs` in a new way
710
- # The `exec` below defines `src1, trg1, ...` etc. in the current scope.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
711
  for i in range(NUM_OF_SRC_SPECIFIC):
712
  idx = i + 1
713
- exec(
714
- f"with gr.Tab(label='({idx})'):\n"
715
- f"\twith gr.Row():\n"
716
- f"\t\tsrc{idx} = gr.Image(interactive=True, type='numpy', label='Source Face {idx}')\n"
717
- f"\t\ttrg{idx} = gr.Image(interactive=True, type='numpy', label='Specific Face {idx}')"
718
- )
719
- distance_slider_specific = gr.Slider(minimum=0, maximum=2, value=0.6, interactive=True, label="Distance")
 
 
 
 
 
 
 
 
 
720
 
721
- with gr.Column(scale=0.5, min_width=100): # Right column
722
- with gr.Row():
723
- swap_button = gr.Button("Swap", variant="primary")
724
- cancel_button = gr.Button("Cancel") # Should always be interactive
725
 
726
- preview_image_output = gr.Image(label="Output", interactive=False) # For image result
727
- save_button_hf = gr.Button("Save", variant="primary") # To save to HF dataset
728
- preview_video_output = gr.Video(label="Output", interactive=False, visible=False) # For video result
729
 
730
- with gr.Row(): # Buttons to open output locations
731
- # Visibility of these buttons could be controlled based on WORKSPACE/OUTPUT_FILE
732
- output_directory_open_button = gr.Button("💚 Open Output Dir", interactive=True, visible=True) # Simplified
733
- output_file_open_button = gr.Button("💘 Open Output File", interactive=True, visible=True) # Simplified
734
 
735
- info_markdown = gr.Markdown(value="...") # For status messages
736
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
737
  with gr.Tab("Swap Condition"):
738
- swap_option_dropdown = gr.Dropdown(swap_options_list_ui, # Use UI-specific list
739
- value=swap_options_list_ui[0] if swap_options_list_ui else None,
740
- label="Swap Condition", interactive=True, show_label=False)
741
- age_number_input = gr.Number(value=25, label="Value for Age", interactive=True, visible=False)
742
-
743
- with gr.Tab("Detection Settings", visible=False): # Hidden Tab
744
- detect_condition_dropdown = gr.Dropdown(detect_conditions, label="Condition", value=DETECT_CONDITION)
745
- detection_size_number = gr.Number(label="Detection Size", value=DETECT_SIZE)
746
- detection_threshold_number = gr.Number(label="Detection Threshold", value=DETECT_THRESH)
747
- apply_detection_settings_btn = gr.Button("Apply Detection Settings")
748
-
749
- with gr.Tab("Output Settings", visible=False): # Hidden Tab
750
- output_directory_text = gr.Text(label="Output Directory", value=DEF_OUTPUT_PATH)
751
- output_name_text = gr.Text(label="Output Name", value="Result")
752
- keep_output_sequence_check = gr.Checkbox(label="Keep output sequence (for video)", value=True)
753
-
754
- with gr.Tab("Other Settings"): # Visible Tab
755
- face_scale_slider = gr.Slider(label="Face Scale", minimum=0, maximum=2, value=0.98, interactive=True)
756
- face_enhancer_name_dropdown = gr.Dropdown(FACE_ENHANCER_LIST, label="Face Enhancer", value="GFPGAN")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
757
 
758
  with gr.Accordion("Advanced Mask", open=False):
759
- enable_face_parser_mask_check = gr.Checkbox(label="Enable Face Parsing", value=True)
760
- mask_include_dropdown = gr.Dropdown(list(mask_regions.keys()), value=MASK_INCLUDE, multiselect=True, label="Include Regions")
761
- mask_soft_kernel_number = gr.Number(label="Soft Erode Kernel", value=MASK_SOFT_KERNEL, minimum=3, visible=False) # Hidden as per original
762
- mask_soft_iterations_number = gr.Number(label="Soft Erode Iterations", value=MASK_SOFT_ITERATIONS, minimum=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
763
 
764
  with gr.Accordion("Crop Mask", open=False):
765
- crop_top_slider = gr.Slider(label="Top", minimum=0, maximum=511, value=0, step=1)
766
- crop_bott_slider = gr.Slider(label="Bottom", minimum=0, maximum=511, value=511, step=1)
767
- crop_left_slider = gr.Slider(label="Left", minimum=0, maximum=511, value=0, step=1)
768
- crop_right_slider = gr.Slider(label="Right", minimum=0, maximum=511, value=511, step=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
769
 
770
- erode_amount_slider = gr.Slider(label="Mask Erode", minimum=0, maximum=1, value=MASK_ERODE_AMOUNT, step=0.05)
771
- blur_amount_slider = gr.Slider(label="Mask Blur", minimum=0, maximum=1, value=MASK_BLUR_AMOUNT, step=0.05)
772
- enable_laplacian_blend_check = gr.Checkbox(label="Laplacian Blending", value=True)
773
 
774
  ## ------------------------------ GRADIO EVENTS ------------------------------
775
- # Events for hidden video UI - they will still be set up but might not be triggered by user
776
- set_slider_range_event = set_slider_range_btn.click(video_changed, inputs=[video_input], outputs=[start_frame_slider, end_frame_slider, video_fps_num])
777
- trim_and_reload_event = trim_and_reload_btn.click(fn=trim_and_reload, inputs=[video_input, output_directory_text, output_name_text, start_frame_slider, end_frame_slider], outputs=[video_input, info_markdown])
778
- start_frame_slider.release(fn=slider_changed, inputs=[show_trim_preview_btn, video_input, start_frame_slider], outputs=[preview_image_output, preview_video_output])
779
- end_frame_slider.release(fn=slider_changed, inputs=[show_trim_preview_btn, video_input, end_frame_slider], outputs=[preview_image_output, preview_video_output])
780
-
781
- # Input type change is disabled as UI is hidden
782
- # input_type_radio_hidden.change(...)
783
-
784
- swap_option_dropdown.change(swap_option_changed, inputs=[swap_option_dropdown], outputs=[age_number_input, specific_face_group, source_image_input])
785
- apply_detection_settings_btn.click(analyse_settings_changed, inputs=[detect_condition_dropdown, detection_size_number, detection_threshold_number], outputs=[info_markdown])
786
-
787
- # Collect specific face components (src1, trg1, etc.) which are defined by exec() earlier
788
- _src_specific_components_tuple = ()
789
- _s_names = []
790
- for i in range(NUM_OF_SRC_SPECIFIC): _s_names.append(f"src{i+1}")
791
- for i in range(NUM_OF_SRC_SPECIFIC): _s_names.append(f"trg{i+1}")
792
- exec(f"_src_specific_components_tuple = ({','.join(_s_names)})")
793
-
794
- swap_inputs_list = [
795
- image_input, video_input, direc_input_text, source_image_input,
796
- output_directory_text, output_name_text, keep_output_sequence_check,
797
- swap_option_dropdown, age_number_input, distance_slider_specific,
798
- face_enhancer_name_dropdown, enable_face_parser_mask_check, mask_include_dropdown,
799
- mask_soft_kernel_number, mask_soft_iterations_number, # UI names
800
- blur_amount_slider, erode_amount_slider, face_scale_slider, enable_laplacian_blend_check,
801
- crop_top_slider, crop_bott_slider, crop_left_slider, crop_right_slider,
802
- *_src_specific_components_tuple
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
803
  ]
804
 
805
- # Outputs for the process function's yield.
806
- # Matches: message, preview_image_update, dir_btn_update, vid_btn_update, preview_vid_update
807
- # The ui_before/after functions return updates for these specific components in order.
808
- swap_outputs_list = [
809
- info_markdown,
810
- preview_image_output, # Updated by ui_before[0] etc.
811
- output_directory_open_button, # Updated by ui_before[1] etc. (interactive toggle)
812
- output_file_open_button, # Updated by ui_before[2] etc. (interactive toggle)
813
- preview_video_output # Updated by ui_before[3] etc.
814
  ]
815
 
816
- swap_event = swap_button.click(fn=process, inputs=swap_inputs_list, outputs=swap_outputs_list, show_progress=True)
 
 
817
 
818
- cancel_event = cancel_button.click(fn=stop_running, inputs=None, outputs=[info_markdown],
819
- cancels=[swap_event, trim_and_reload_event, set_slider_range_event]) # Removed slider events from cancel as they are minor
820
-
821
- output_directory_open_button.click(lambda: open_directory(path=WORKSPACE) if WORKSPACE and os.path.isdir(WORKSPACE) else print(f"Workspace '{WORKSPACE}' not set or not a directory."), inputs=None, outputs=None)
822
- output_file_open_button.click(lambda: open_directory(path=OUTPUT_FILE) if OUTPUT_FILE and os.path.exists(OUTPUT_FILE) else print(f"Output file '{OUTPUT_FILE}' not set or does not exist."), inputs=None, outputs=None)
823
-
824
-
825
- def save_to_huggingface_dataset(image_numpy_array):
826
- if image_numpy_array is None: return "❌ 出力画像がありません。"
827
- save_repo_id = os.environ.get("SAVE_REPO")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
828
  hf_token = os.environ.get("HF_TOKEN")
829
- if not save_repo_id or not hf_token: return "❌ 環境変数 SAVE_REPO または HF_TOKEN が未設定"
 
 
830
 
831
  try:
832
- cache_dir = os.path.join(os.getcwd(), ".hf_datasets_cache_save")
833
- os.makedirs(cache_dir, exist_ok=True)
834
  try:
835
- dataset = load_dataset(save_repo_id, split='train', use_auth_token=hf_token, cache_dir=cache_dir)
836
- except Exception: # If dataset doesn't exist or other loading error
837
- print(f"Dataset {save_repo_id} not found or error loading, creating new one.")
838
- ds_features = Features({"image": DatasetImage(), "timestamp": Value("string")})
839
- dataset = Dataset.from_dict({"image": [], "timestamp": []}, features=ds_features)
840
-
841
- pil_img_to_save = Image.fromarray(image_numpy_array.astype('uint8'), 'RGB')
842
- temp_img_path = os.path.join(os.getcwd(), "temp_hf_upload.png")
843
- pil_img_to_save.save(temp_img_path)
844
-
845
- new_data_entry = Dataset.from_dict({
846
- "image": [temp_img_path],
847
- "timestamp": [str(datetime.datetime.now())]},
848
- features=dataset.features # Use existing dataset features
849
- )
850
- updated_dataset = concatenate_datasets([dataset, new_data_entry])
851
- updated_dataset.push_to_hub(save_repo_id, token=hf_token)
852
- os.remove(temp_img_path)
 
 
 
 
 
 
 
 
853
  return "✔️ 画像をHugging Faceデータセットに保存しました"
854
  except Exception as e:
855
- if 'temp_img_path' in locals() and os.path.exists(temp_img_path): os.remove(temp_img_path)
856
- return f"❌ HFへの保存中にエラー: {str(e)}"
857
-
858
- save_button_hf.click(fn=save_to_huggingface_dataset, inputs=[preview_image_output], outputs=[info_markdown], show_progress=True)
859
-
860
- def load_target_from_hf_with_choice(repo_choice_val):
861
- pil_img, msg = load_latest_image_from_repo(repo_choice_val)
862
- # image_input (gr.Image) can take PIL image directly
863
- return pil_img, f"### {msg}" # Update image_input and target_load_info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
864
 
865
- target_load_button.click(fn=load_target_from_hf_with_choice, inputs=[repo_choice_radio], outputs=[image_input, target_load_info], show_progress=True)
866
- repo_choice_radio.change(fn=load_target_from_hf_with_choice, inputs=[repo_choice_radio], outputs=[image_input, target_load_info], show_progress=True)
867
 
868
  if __name__ == "__main__":
869
- if USE_COLAB: print("Running in colab mode, share=True might be set by launch()")
870
- # Use debug=True for more detailed Gradio logs if needed
871
- interface.queue(concurrency_count=2, max_size=20).launch(share=USE_COLAB, debug=True)
 
 
12
  import onnxruntime
13
  import numpy as np
14
  import gradio as gr
15
+ import threading
16
+ import queue
17
 
18
  from datasets import Dataset, Features, Image as DatasetImage, Value, load_dataset, concatenate_datasets
19
  from PIL import Image
 
25
  from moviepy.editor import VideoFileClip
26
 
27
  from face_swapper import Inswapper, paste_to_whole
28
+ from face_analyser import detect_conditions, get_analysed_data, swap_options_list
29
  from face_parsing import init_parsing_model, get_parsed_mask, mask_regions, mask_regions_to_list
30
  from face_enhancer import get_available_enhancer_names, load_face_enhancer_model, cv2_interpolations
31
  from utils import trim_video, StreamerThread, ProcessBar, open_directory, split_list_by_lengths, merge_img_sequence_from_ref, create_image_grid
 
35
  parser = argparse.ArgumentParser(description="Free Face Swapper")
36
  parser.add_argument("--out_dir", help="Default Output directory", default=os.getcwd())
37
  parser.add_argument("--batch_size", help="Gpu batch size", default=32)
38
+ parser.add_argument("--cuda", action="store_true", help="Enable cuda", default=True)
39
  parser.add_argument(
40
  "--colab", action="store_true", help="Enable colab mode", default=False
41
  )
 
49
  BATCH_SIZE = int(user_args.batch_size)
50
  WORKSPACE = None
51
  OUTPUT_FILE = None
52
+ CURRENT_FRAME = None
53
+ STREAMER = None
54
  DETECT_CONDITION = "best detection"
55
  DETECT_SIZE = 640
56
  DETECT_THRESH = 0.7
57
+ NUM_OF_SRC_SPECIFIC = 10
58
  MASK_INCLUDE = [
59
  "Skin",
60
  "R-Eyebrow",
 
73
 
74
  FACE_SWAPPER = None
75
  FACE_ANALYSER = None
76
+ FACE_ENHANCER = "GFPGAN"
 
77
  FACE_PARSER = None
78
+ FACE_ENHANCER_LIST = ["None"]
79
  FACE_ENHANCER_LIST.extend(get_available_enhancer_names())
80
  FACE_ENHANCER_LIST.extend(cv2_interpolations)
81
 
 
 
 
82
  ## ------------------------------ SET EXECUTION PROVIDER ------------------------------
83
+ # Note: Non CUDA users may change settings here
84
+
85
  PROVIDER = ["CPUExecutionProvider"]
86
  if USE_CUDA:
87
  available_providers = onnxruntime.get_available_providers()
 
89
  if "CUDAExecutionProvider" in available_providers:
90
  print("\n********** Running on CUDA **********\n")
91
  PROVIDER = ["CUDAExecutionProvider", "CPUExecutionProvider"]
92
+
93
  else:
94
+ USE_CUDA = False
95
  print("\n********** CUDA unavailable running on CPU **********\n")
96
  else:
97
+ USE_CUDA = False
98
  print("\n********** Running on CPU **********\n")
99
 
100
  device = "cuda" if USE_CUDA else "cpu"
101
+ EMPTY_CACHE = lambda: torch.cuda.empty_cache() if device == "cuda" else None
102
+ print(onnxruntime.get_available_providers())
103
+ print(torch.cuda.is_available())
104
+ print(torch.cuda.device_count())
105
+ print(torch.cuda.current_device())
106
+ print(torch.cuda.get_device_name(0))
 
 
 
107
 
108
  ## ------------------------------ LOAD MODELS ------------------------------
109
 
110
  def load_face_analyser_model(name="buffalo_l"):
111
  global FACE_ANALYSER
112
  if FACE_ANALYSER is None:
 
113
  FACE_ANALYSER = insightface.app.FaceAnalysis(name=name, providers=PROVIDER)
114
  FACE_ANALYSER.prepare(
115
  ctx_id=0, det_size=(DETECT_SIZE, DETECT_SIZE), det_thresh=DETECT_THRESH
116
  )
 
117
 
118
 
119
+ def load_face_swapper_model(path="./assets/pretrained_models/inswapper_128.onnx"):
120
  global FACE_SWAPPER
121
  if FACE_SWAPPER is None:
 
122
  batch = int(BATCH_SIZE) if device == "cuda" else 1
123
+ FACE_SWAPPER = Inswapper(model_file=path, batch_size=batch, providers=PROVIDER)
 
124
 
125
 
126
+ def load_face_parser_model(path="./assets/pretrained_models/79999_iter.pth"):
127
  global FACE_PARSER
128
  if FACE_PARSER is None:
129
+ FACE_PARSER = init_parsing_model(path, device=device)
130
+
 
131
 
 
132
  load_face_analyser_model()
133
  load_face_swapper_model()
 
134
 
135
  ## ------------------------------ MAIN PROCESS ------------------------------
136
+
137
+
138
  def process(
139
+ input_type,
140
  image_path,
141
+ video_path,
142
+ directory_path,
143
  source_path,
144
  output_path,
145
  output_name,
146
+ keep_output_sequence,
147
+ condition,
148
+ age,
149
+ distance,
150
+ face_enhancer_name,
151
+ enable_face_parser,
152
+ mask_includes,
153
+ mask_soft_kernel,
154
+ mask_soft_iterations,
155
+ blur_amount,
156
+ erode_amount,
157
+ face_scale,
158
+ enable_laplacian_blend,
159
+ crop_top,
160
+ crop_bott,
161
+ crop_left,
162
+ crop_right,
163
+ *specifics,
164
  ):
165
+ global WORKSPACE
166
+ global OUTPUT_FILE
167
+ global PREVIEW
168
  WORKSPACE, OUTPUT_FILE, PREVIEW = None, None, None
 
169
 
170
+ ## ------------------------------ GUI UPDATE FUNC ------------------------------
 
 
171
 
172
+ def ui_before():
 
173
  return (
174
+ gr.update(visible=True, value=PREVIEW),
175
+ gr.update(interactive=False),
176
+ gr.update(interactive=False),
177
+ gr.update(visible=False),
178
  )
179
 
180
+ def ui_after():
181
  return (
182
+ gr.update(visible=True, value=PREVIEW),
183
+ gr.update(interactive=True),
184
+ gr.update(interactive=True),
185
+ gr.update(visible=False),
186
  )
187
 
188
+ def ui_after_vid():
189
  return (
190
+ gr.update(visible=False),
191
+ gr.update(interactive=True),
192
+ gr.update(interactive=True),
193
+ gr.update(value=OUTPUT_FILE, visible=True),
194
  )
195
 
196
  start_time = time.time()
197
+ total_exec_time = lambda start_time: divmod(time.time() - start_time, 60)
198
+ get_finsh_text = lambda start_time: f"✔️ Completed in {int(total_exec_time(start_time)[0])} min {int(total_exec_time(start_time)[1])} sec."
199
+
200
+ ## ------------------------------ PREPARE INPUTS & LOAD MODELS ------------------------------
201
 
202
+
203
+
204
+ yield "### \n 🌀 Loading face analyser model...", *ui_before()
205
+ load_face_analyser_model()
206
 
207
+ yield "### \n ⚙️ Loading face swapper model...", *ui_before()
208
+ load_face_swapper_model()
209
 
210
+ if face_enhancer_name != "NONE":
211
+ if face_enhancer_name not in cv2_interpolations:
212
+ yield f"### \n 💡 Loading {face_enhancer_name} model...", *ui_before()
213
+ FACE_ENHANCER = load_face_enhancer_model(name=face_enhancer_name, device=device)
214
+ else:
215
+ FACE_ENHANCER = None
216
+
217
+ if enable_face_parser:
218
+ yield "### \n 📀 Loading face parsing model...", *ui_before()
219
+ load_face_parser_model()
220
+
221
+ includes = mask_regions_to_list(mask_includes)
222
+ specifics = list(specifics)
223
+ half = len(specifics) // 2
224
+ sources = specifics[:half]
225
+ specifics = specifics[half:]
226
+ if crop_top > crop_bott:
227
+ crop_top, crop_bott = crop_bott, crop_top
228
+ if crop_left > crop_right:
229
+ crop_left, crop_right = crop_right, crop_left
230
+ crop_mask = (crop_top, 511-crop_bott, crop_left, 511-crop_right)
231
+
232
+ def swap_process(image_sequence):
233
+ ## ------------------------------ CONTENT CHECK ------------------------------
234
+
235
+
236
+ yield "### \n 🧿 Analysing face data...", *ui_before()
237
+ if condition != "Specific Face":
238
+ source_data = source_path, age
239
  else:
240
+ source_data = ((sources, specifics), distance)
241
+ analysed_targets, analysed_sources, whole_frame_list, num_faces_per_frame = get_analysed_data(
242
+ FACE_ANALYSER,
243
+ image_sequence,
244
+ source_data,
245
+ swap_condition=condition,
246
+ detect_condition=DETECT_CONDITION,
247
+ scale=face_scale
248
+ )
249
 
250
+ ## ------------------------------ SWAP FUNC ------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
+ yield "### \n 🧶 Generating faces...", *ui_before()
253
+ preds = []
254
+ matrs = []
255
+ count = 0
256
+ global PREVIEW
257
+ for batch_pred, batch_matr in FACE_SWAPPER.batch_forward(whole_frame_list, analysed_targets, analysed_sources):
258
+ preds.extend(batch_pred)
259
+ matrs.extend(batch_matr)
260
+ EMPTY_CACHE()
261
+ count += 1
262
 
263
+ if USE_CUDA:
264
+ image_grid = create_image_grid(batch_pred, size=128)
265
+ PREVIEW = image_grid[:, :, ::-1]
266
+ yield f"### \n 🧩 Generating face Batch {count}", *ui_before()
267
 
268
+ ## ------------------------------ FACE ENHANCEMENT ------------------------------
 
 
 
 
 
 
269
 
270
+ generated_len = len(preds)
271
+ if face_enhancer_name != "NONE":
272
+ yield f"### \n 🎲 Upscaling faces with {face_enhancer_name}...", *ui_before()
273
+ for idx, pred in tqdm(enumerate(preds), total=generated_len, desc=f"Upscaling with {face_enhancer_name}"):
274
+ enhancer_model, enhancer_model_runner = FACE_ENHANCER
275
+ pred = enhancer_model_runner(pred, enhancer_model)
276
+ preds[idx] = cv2.resize(pred, (512,512))
277
+ EMPTY_CACHE()
278
 
279
+ ## ------------------------------ FACE PARSING ------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
281
+ if enable_face_parser:
282
+ yield "### \n 🎨 Face-parsing mask...", *ui_before()
283
+ masks = []
284
+ count = 0
285
+ for batch_mask in get_parsed_mask(FACE_PARSER, preds, classes=includes, device=device, batch_size=BATCH_SIZE, softness=int(mask_soft_iterations)):
286
+ masks.append(batch_mask)
287
+ EMPTY_CACHE()
288
+ count += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
 
290
+ if len(batch_mask) > 1:
291
+ image_grid = create_image_grid(batch_mask, size=128)
292
+ PREVIEW = image_grid[:, :, ::-1]
293
+ yield f"### \n 🪙 Face parsing Batch {count}", *ui_before()
294
+ masks = np.concatenate(masks, axis=0) if len(masks) >= 1 else masks
295
+ else:
296
+ masks = [None] * generated_len
297
+
298
+ ## ------------------------------ SPLIT LIST ------------------------------
299
+
300
+ split_preds = split_list_by_lengths(preds, num_faces_per_frame)
301
+ del preds
302
+ split_matrs = split_list_by_lengths(matrs, num_faces_per_frame)
303
+ del matrs
304
+ split_masks = split_list_by_lengths(masks, num_faces_per_frame)
305
+ del masks
306
+
307
+ ## ------------------------------ PASTE-BACK ------------------------------
308
+
309
+ yield "### \n 🧿 Pasting back...", *ui_before()
310
+ def post_process(frame_idx, frame_img, split_preds, split_matrs, split_masks, enable_laplacian_blend, crop_mask, blur_amount, erode_amount):
311
+ whole_img_path = frame_img
312
+ whole_img = cv2.imread(whole_img_path)
313
+ blend_method = 'laplacian' if enable_laplacian_blend else 'linear'
314
+ for p, m, mask in zip(split_preds[frame_idx], split_matrs[frame_idx], split_masks[frame_idx]):
315
+ p = cv2.resize(p, (512,512))
316
+ mask = cv2.resize(mask, (512,512)) if mask is not None else None
317
+ m /= 0.25
318
+ whole_img = paste_to_whole(p, whole_img, m, mask=mask, crop_mask=crop_mask, blend_method=blend_method, blur_amount=blur_amount, erode_amount=erode_amount)
319
+ cv2.imwrite(whole_img_path, whole_img)
320
+
321
+ def concurrent_post_process(image_sequence, *args):
322
  with concurrent.futures.ThreadPoolExecutor() as executor:
323
+ futures = []
324
+ for idx, frame_img in enumerate(image_sequence):
325
+ future = executor.submit(post_process, idx, frame_img, *args)
326
+ futures.append(future)
327
+
328
  for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Pasting back"):
329
+ result = future.result()
330
+
331
+ concurrent_post_process(
332
+ image_sequence,
333
+ split_preds,
334
+ split_matrs,
335
+ split_masks,
336
+ enable_laplacian_blend,
337
+ crop_mask,
338
+ blur_amount,
339
+ erode_amount
340
+ )
341
 
 
 
 
 
 
342
 
343
+ ## ------------------------------ IMAGE ------------------------------
 
 
 
 
 
344
 
345
+ if input_type == "Image":
346
+ target = cv2.imread(image_path)
347
+ output_file = os.path.join(output_path, output_name + ".png")
348
+ cv2.imwrite(output_file, target)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
 
350
+ for info_update in swap_process([output_file]):
351
+ yield info_update
352
+
353
+ OUTPUT_FILE = output_file
354
+ WORKSPACE = output_path
355
+ PREVIEW = cv2.imread(output_file)[:, :, ::-1]
356
+
357
+ yield get_finsh_text(start_time), *ui_after()
358
+
359
+ ## ------------------------------ VIDEO ------------------------------
360
+
361
+ elif input_type == "Video":
362
+ temp_path = os.path.join(output_path, output_name, "sequence")
363
+ os.makedirs(temp_path, exist_ok=True)
364
+
365
+ yield "### \n ⌛ Extracting video frames...", *ui_before()
366
+ image_sequence = []
367
+ cap = cv2.VideoCapture(video_path)
368
+ curr_idx = 0
369
+ while True:
370
+ ret, frame = cap.read()
371
+ if not ret:break
372
+ frame_path = os.path.join(temp_path, f"frame_{curr_idx}.jpg")
373
+ cv2.imwrite(frame_path, frame)
374
+ image_sequence.append(frame_path)
375
+ curr_idx += 1
376
+ cap.release()
377
+ cv2.destroyAllWindows()
378
+
379
+ for info_update in swap_process(image_sequence):
380
+ yield info_update
381
+
382
+ yield "### \n ⌛ Merging sequence...", *ui_before()
383
+ output_video_path = os.path.join(output_path, output_name + ".mp4")
384
+ merge_img_sequence_from_ref(video_path, image_sequence, output_video_path)
385
+
386
+ if os.path.exists(temp_path) and not keep_output_sequence:
387
+ yield "### \n ⌛ Removing temporary files...", *ui_before()
388
+ shutil.rmtree(temp_path)
389
+
390
+ WORKSPACE = output_path
391
+ OUTPUT_FILE = output_video_path
392
+
393
+ yield get_finsh_text(start_time), *ui_after_vid()
394
+
395
+ ## ------------------------------ DIRECTORY ------------------------------
396
+
397
+ elif input_type == "Directory":
398
+ extensions = ["jpg", "jpeg", "png", "bmp", "tiff", "ico", "webp"]
399
+ temp_path = os.path.join(output_path, output_name)
400
+ if os.path.exists(temp_path):
401
+ shutil.rmtree(temp_path)
402
+ os.mkdir(temp_path)
403
+
404
+ file_paths =[]
405
+ for file_path in glob.glob(os.path.join(directory_path, "*")):
406
+ if any(file_path.lower().endswith(ext) for ext in extensions):
407
+ img = cv2.imread(file_path)
408
+ new_file_path = os.path.join(temp_path, os.path.basename(file_path))
409
+ cv2.imwrite(new_file_path, img)
410
+ file_paths.append(new_file_path)
411
+
412
+ for info_update in swap_process(file_paths):
413
+ yield info_update
414
+
415
+ PREVIEW = cv2.imread(file_paths[-1])[:, :, ::-1]
416
+ WORKSPACE = temp_path
417
+ OUTPUT_FILE = file_paths[-1]
418
+
419
+ yield get_finsh_text(start_time), *ui_after()
420
+
421
+ ## ------------------------------ STREAM ------------------------------
422
+
423
+ elif input_type == "Stream":
424
+ pass
425
 
426
 
427
  ## ------------------------------ GRADIO FUNC ------------------------------
428
+
429
+
430
  def update_radio(value):
431
+ if value == "Image":
432
+ return (
433
+ gr.update(visible=True),
434
+ gr.update(visible=False),
435
+ gr.update(visible=False),
436
+ )
437
+ elif value == "Video":
438
+ return (
439
+ gr.update(visible=False),
440
+ gr.update(visible=True),
441
+ gr.update(visible=False),
442
+ )
443
+ elif value == "Directory":
444
+ return (
445
+ gr.update(visible=False),
446
+ gr.update(visible=False),
447
+ gr.update(visible=True),
448
+ )
449
+ elif value == "Stream":
450
+ return (
451
+ gr.update(visible=False),
452
+ gr.update(visible=False),
453
+ gr.update(visible=True),
454
+ )
455
+
456
 
457
  def swap_option_changed(value):
458
+ if value.startswith("Age"):
459
+ return (
460
+ gr.update(visible=True),
461
+ gr.update(visible=False),
462
+ gr.update(visible=True),
463
+ )
464
+ elif value == "Specific Face":
465
+ return (
466
+ gr.update(visible=False),
467
+ gr.update(visible=True),
468
+ gr.update(visible=False),
469
+ )
470
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
471
 
 
472
 
473
def video_changed(video_path):
    """Refresh the trim sliders and fps box for a newly selected video.

    Returns updates for (start_frame slider, end_frame slider, fps number).
    On failure the sliders are reset so the UI stays usable.
    """
    sliders_update = gr.Slider.update
    number_update = gr.Number.update

    # No video selected: put the controls into an inert state.
    if video_path is None:
        return (
            sliders_update(minimum=0, maximum=0, value=0),
            sliders_update(minimum=1, maximum=1, value=1),
            number_update(value=1),
        )
    clip = None
    try:
        clip = VideoFileClip(video_path)
        fps = clip.fps
        # NOTE(review): moviepy's reader.nframes can overestimate by a
        # frame or two depending on the container — confirm if exact
        # frame counts matter downstream.
        total_frames = clip.reader.nframes
        return (
            sliders_update(minimum=0, maximum=total_frames, value=0, interactive=True),
            sliders_update(
                minimum=0, maximum=total_frames, value=total_frames, interactive=True
            ),
            number_update(value=fps),
        )
    except Exception as e:  # was a bare `except:` — don't swallow KeyboardInterrupt
        print(f"video_changed failed for {video_path!r}: {e}")
        return (
            sliders_update(value=0),
            sliders_update(value=0),
            number_update(value=1),
        )
    finally:
        # Always release the clip, even when probing it raised.
        if clip is not None:
            clip.close()
502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503
 
504
def analyse_settings_changed(detect_condition, detection_size, detection_threshold):
    """Rebuild the global face analyser with new detection settings.

    Generator used by Gradio: yields a progress message, then a
    confirmation once the analyser has been re-prepared.
    """
    yield "### \n Applying new values..."
    global FACE_ANALYSER, DETECT_CONDITION
    DETECT_CONDITION = detect_condition
    size = int(detection_size)
    FACE_ANALYSER = insightface.app.FaceAnalysis(name="buffalo_l", providers=PROVIDER)
    FACE_ANALYSER.prepare(
        ctx_id=0,
        det_size=(size, size),
        det_thresh=float(detection_threshold),
    )
    yield f"### \n ✔️ Applied detect condition:{detect_condition}, detection size: {detection_size}, detection threshold: {detection_threshold}"
+
517
+
518
def stop_running():
    """Cancel the active stream (if any) and report cancellation to the UI."""
    global STREAMER
    # STREAMER may be None (stream mode never started); hasattr(None, "stop")
    # is simply False, so no extra guard is needed.
    if hasattr(STREAMER, "stop"):
        STREAMER.stop()
        STREAMER = None
    return "Cancelled"
+
525
+
526
def slider_changed(show_frame, video_path, frame_index):
    """Preview the video frame at the trim-slider position.

    Returns (image update, video update): the still preview is shown and
    the video player hidden. When previews are disabled or no video is
    loaded, returns (None, None) so Gradio leaves both components alone.
    """
    # Merged the two identical guard returns into one condition.
    if not show_frame or video_path is None:
        return None, None
    clip = VideoFileClip(video_path)
    try:
        # frame_index is in frames; get_frame expects seconds.
        frame_array = np.array(clip.get_frame(frame_index / clip.fps))
    finally:
        # Previously leaked the clip if get_frame raised.
        clip.close()
    return gr.Image.update(value=frame_array, visible=True), gr.Video.update(
        visible=False
    )
+
539
+
540
def trim_and_reload(video_path, output_path, output_name, start_frame, stop_frame):
    """Trim the target video to [start_frame, stop_frame] and reload it.

    Generator used by Gradio: yields (video path, status markdown) pairs —
    first a progress message, then either the trimmed clip or the original
    path when trimming fails.
    """
    yield video_path, f"### \n 🌈 Trimming video frame {start_frame} to {stop_frame}..."
    destination = os.path.join(output_path, output_name)
    try:
        trimmed = trim_video(video_path, destination, start_frame, stop_frame)
    except Exception as err:
        print(err)
        yield video_path, "### \n 🔥 Video trimming failed. See console for more info."
    else:
        yield trimmed, "### \n ✔️ Video trimmed and reloaded."
 
550
 
551
+ #
552
+ def load_latest_image_from_repo(repo_choice): # 引数を追加
553
+ dataset_repo = os.environ.get("DATASET_REPO")
554
+ dataset_repo2 = os.environ.get("DATASET_REPO2") # dataset_repo2 を追加
555
  hf_token = os.environ.get("HF_TOKEN")
556
 
557
+ if not hf_token:
558
+ return None, "❌ 環境変数 HF_TOKEN が設定されていません"
559
 
560
+ if repo_choice == "repo1" and dataset_repo: # 選択されたリポジトリに応じて切り替え
561
+ target_repo = dataset_repo
562
+ elif repo_choice == "repo2" and dataset_repo2:
563
+ target_repo = dataset_repo2
564
+ else:
565
+ return None, f"❌ 選択されたリポジトリの環境変数が設定されていません (選択: {repo_choice})"
566
 
567
  try:
568
+ # last_update.txt の URL を作成
569
+ last_update_url = f"https://huggingface.co/datasets/{target_repo}/resolve/main/images/last_update.txt"
570
  headers = {'Authorization': f'Bearer {hf_token}'}
571
+
572
+ # last_update.txt を取得
573
  print(f"Fetching last_update.txt from {last_update_url}")
574
+ response = requests.get(last_update_url, headers=headers)
575
+ print(f"Status code for last_update.txt: {response.status_code}")
576
+ if response.status_code != 200:
577
+ return None, f"❌ last_update.txt を取得できません (HTTP {response.status_code})"
578
+
579
+ image_url = response.text.strip()
580
+ if not image_url:
581
+ return None, "❌ last_update.txt が空です"
582
+ print(f"Image URL from last_update.txt: {image_url}")
583
+
584
+ # 画像を取得
585
+ print(f"Fetching image from {image_url}")
586
+ response = requests.get(image_url, headers=headers)
587
+ print(f"Status code for image: {response.status_code}")
588
+ if response.status_code == 200:
589
+ img = Image.open(BytesIO(response.content))
590
+ print("✔️ 画像の取得に成功しました")
591
+ return img, "✔️ 最新の画像をロードしました"
592
+ else:
593
+ return None, f"❌ 画像を取得できません (HTTP {response.status_code})"
594
  except Exception as e:
595
+ print(f"Exception occurred: {str(e)}")
596
+ return None, f"❌ エラーが発生しました: {str(e)}"
597
+
598
def load_target_image():
    """Load the newest image from the default repo into the target slot.

    Bug fix: ``load_latest_image_from_repo`` takes a required
    ``repo_choice`` argument, so the previous no-argument call raised
    TypeError on every invocation. Pass "repo1", matching the UI
    radio's initial value.

    Returns (PIL.Image | None, markdown status string).
    """
    img, message = load_latest_image_from_repo("repo1")
    status = f"### {message}"
    if img is None:
        return None, status
    return img, status
+
604
 
605
  ## ------------------------------ GRADIO GUI ------------------------------
606
+
607
# Stylesheet handed to gr.Blocks below: stretch the app container to the
# full viewport and keep padding minimal so both columns stay visible.
css = """
.gradio-container {
    width: 100%;
    height: 100vh;
    overflow: hidden;
    margin: 0 auto !important;
    padding: 20px !important;
    max-width: 100% !important;
}

"""
618
 
619
  with gr.Blocks(css=css) as interface:
620
+ #
621
  with gr.Row():
622
+ ##左の列
623
+ with gr.Column(scale=0.5, min_width=100):
624
+
625
+ ##ターゲッ画像入力エ
626
+ with gr.Group():
627
+ with gr.Group(visible=True) as input_image_group:
628
+ # ラジオボタンを追加
629
+ repo_choice_radio = gr.Radio(
630
+ ["repo1", "repo2"],
631
+ label="リポジトリを選択",
632
+ value="repo1" # デフォルト値
633
+ )
634
+
635
+ target_load_button = gr.Button("TARGET_LOAD" , variant="primary")
636
+ target_load_info = gr.Markdown(value="...",visible=False)
637
+ image_input = gr.Image(
638
+ label="Target Image", interactive=True, type="filepath"
639
+ )
640
+
641
+
642
+ input_type = gr.Radio(
643
+ ["Image", "Video"],
644
+ label="Target Type",
645
+ value="Image",
646
+ #visible=False
647
+ )
648
+
649
+ with gr.Group(visible=False) as input_video_group:
650
+ vid_widget = gr.Video if USE_COLAB else gr.Text
651
+ video_input = gr.Video(
652
+ label="Target Video", interactive=True
653
+ )
654
  with gr.Accordion("Trim video", open=False):
655
+ with gr.Column():
656
+ with gr.Row():
657
+ set_slider_range_btn = gr.Button(
658
+ "Set frame range", interactive=True
659
+ )
660
+ show_trim_preview_btn = gr.Checkbox(
661
+ label="Show frame when slider change",
662
+ value=True,
663
+ interactive=True,
664
+ )
665
+
666
+ video_fps = gr.Number(
667
+ value=30,
668
+ interactive=False,
669
+ label="Fps",
670
+ visible=False,
671
+ )
672
+ start_frame = gr.Slider(
673
+ minimum=0,
674
+ maximum=1,
675
+ value=0,
676
+ step=1,
677
+ interactive=True,
678
+ label="Start Frame",
679
+ info="",
680
+ )
681
+ end_frame = gr.Slider(
682
+ minimum=0,
683
+ maximum=1,
684
+ value=1,
685
+ step=1,
686
+ interactive=True,
687
+ label="End Frame",
688
+ info="",
689
+ )
690
+ trim_and_reload_btn = gr.Button(
691
+ "Trim and Reload", interactive=True
692
+ )
693
+
694
+ with gr.Group(visible=False) as input_directory_group:
695
+
696
+ direc_input = gr.Text(label="Path", interactive=True)
697
+ ##画像入力エリア終了
698
+
699
+ ##ソース画像入力エリア
700
+ source_image_input = gr.Image(
701
+ label="Source face", type="filepath", interactive=True
702
+ )
703
+
704
+ with gr.Group(visible=False) as specific_face:
705
  for i in range(NUM_OF_SRC_SPECIFIC):
706
  idx = i + 1
707
+ code = "\n"
708
+ code += f"with gr.Tab(label='({idx})'):"
709
+ code += "\n\twith gr.Row():"
710
+ code += f"\n\t\tsrc{idx} = gr.Image(interactive=True, type='numpy', label='Source Face {idx}')"
711
+ code += f"\n\t\ttrg{idx} = gr.Image(interactive=True, type='numpy', label='Specific Face {idx}')"
712
+ exec(code)
713
+
714
+ distance_slider = gr.Slider(
715
+ minimum=0,
716
+ maximum=2,
717
+ value=0.6,
718
+ interactive=True,
719
+ label="Distance",
720
+ info="Lower distance is more similar and higher distance is less similar to the target face.",
721
+ )
722
+ ##ソース画像入力エリア終了
723
 
 
 
 
 
724
 
 
 
 
725
 
 
 
 
 
726
 
 
727
 
728
+
729
+
730
+ ##右の列
731
+ with gr.Column(scale=0.5, min_width=100):
732
+ ##画像出力パート
733
+ with gr.Row():
734
+ swap_button = gr.Button("Swap", variant="primary")
735
+ cancel_button = gr.Button("Cancel")
736
+
737
+ preview_image = gr.Image(label="Output", interactive=False)
738
+ save_button = gr.Button("Save" , variant="primary") # 新しく追加するボタン
739
+ preview_video = gr.Video(
740
+ label="Output", interactive=False, visible=False
741
+ )
742
+
743
+ with gr.Row():
744
+ output_directory_button = gr.Button(
745
+ "💚", interactive=False, visible=False
746
+ )
747
+ output_video_button = gr.Button(
748
+ "💘", interactive=False, visible=False
749
+ )
750
+ info = gr.Markdown(value="...")
751
+ ##画像出力パート終了
752
+
753
+ ##4つのタブパート
754
  with gr.Tab("Swap Condition"):
755
+ swap_option = gr.Dropdown(
756
+ swap_options_list,
757
+ info="Choose which face or faces in the target image to swap.",
758
+ multiselect=False,
759
+ show_label=False,
760
+ value=swap_options_list[0],
761
+ interactive=True,
762
+ )
763
+ age = gr.Number(
764
+ value=25, label="Value", interactive=True, visible=False
765
+ )
766
+
767
+ with gr.Tab("Detection Settings",visible=False):
768
+ detect_condition_dropdown = gr.Dropdown(
769
+ detect_conditions,
770
+ label="Condition",
771
+ value=DETECT_CONDITION,
772
+ interactive=True,
773
+ info="This condition is only used when multiple faces are detected on source or specific image.",
774
+ )
775
+ detection_size = gr.Number(
776
+ label="Detection Size", value=DETECT_SIZE, interactive=True
777
+ )
778
+ detection_threshold = gr.Number(
779
+ label="Detection Threshold",
780
+ value=DETECT_THRESH,
781
+ interactive=True,
782
+ )
783
+ apply_detection_settings = gr.Button("Apply settings")
784
+
785
+ with gr.Tab("Output Settings",visible=False):
786
+ output_directory = gr.Text(
787
+ label="Output Directory",
788
+ value=DEF_OUTPUT_PATH,
789
+ interactive=True,
790
+ )
791
+ output_name = gr.Text(
792
+ label="Output Name", value="Result", interactive=True
793
+ )
794
+ keep_output_sequence = gr.Checkbox(
795
+ label="Keep output sequence", value=True, interactive=True
796
+ )
797
+
798
+ with gr.Tab("Other Settings"):
799
+ face_scale = gr.Slider(
800
+ label="Face Scale",
801
+ minimum=0,
802
+ maximum=2,
803
+ value=0.98,
804
+ interactive=True,
805
+ )
806
+
807
+ face_enhancer_name = gr.Dropdown(
808
+ FACE_ENHANCER_LIST, label="Face Enhancer", value="GFPGAN", multiselect=False, interactive=True
809
+ )
810
 
811
  with gr.Accordion("Advanced Mask", open=False):
812
+ enable_face_parser_mask = gr.Checkbox(
813
+ label="Enable Face Parsing",
814
+ value=True,
815
+ interactive=True,
816
+ )
817
+
818
+ mask_include = gr.Dropdown(
819
+ mask_regions.keys(),
820
+ value=MASK_INCLUDE,
821
+ multiselect=True,
822
+ label="Include",
823
+ interactive=True,
824
+ )
825
+ mask_soft_kernel = gr.Number(
826
+ label="Soft Erode Kernel",
827
+ value=MASK_SOFT_KERNEL,
828
+ minimum=3,
829
+ interactive=True,
830
+ visible = False
831
+ )
832
+ mask_soft_iterations = gr.Number(
833
+ label="Soft Erode Iterations",
834
+ value=MASK_SOFT_ITERATIONS,
835
+ minimum=0,
836
+ interactive=True,
837
+
838
+ )
839
+
840
 
841
  with gr.Accordion("Crop Mask", open=False):
842
+ crop_top = gr.Slider(label="Top", minimum=0, maximum=511, value=0, step=1, interactive=True)
843
+ crop_bott = gr.Slider(label="Bottom", minimum=0, maximum=511, value=511, step=1, interactive=True)
844
+ crop_left = gr.Slider(label="Left", minimum=0, maximum=511, value=0, step=1, interactive=True)
845
+ crop_right = gr.Slider(label="Right", minimum=0, maximum=511, value=511, step=1, interactive=True)
846
+
847
+
848
+ erode_amount = gr.Slider(
849
+ label="Mask Erode",
850
+ minimum=0,
851
+ maximum=1,
852
+ value=MASK_ERODE_AMOUNT,
853
+ step=0.05,
854
+ interactive=True,
855
+ )
856
+
857
+ blur_amount = gr.Slider(
858
+ label="Mask Blur",
859
+ minimum=0,
860
+ maximum=1,
861
+ value=MASK_BLUR_AMOUNT,
862
+ step=0.05,
863
+ interactive=True,
864
+ )
865
+
866
+ enable_laplacian_blend = gr.Checkbox(
867
+ label="Laplacian Blending",
868
+ value=True,
869
+ interactive=True,
870
+ )
871
+ ##4つのタブパート終了
872
+
873
+
874
+
875
+
876
 
 
 
 
877
 
878
  ## ------------------------------ GRADIO EVENTS ------------------------------
879
+
880
+ set_slider_range_event = set_slider_range_btn.click(
881
+ video_changed,
882
+ inputs=[video_input],
883
+ outputs=[start_frame, end_frame, video_fps],
884
+ )
885
+
886
+ trim_and_reload_event = trim_and_reload_btn.click(
887
+ fn=trim_and_reload,
888
+ inputs=[video_input, output_directory, output_name, start_frame, end_frame],
889
+ outputs=[video_input, info],
890
+ )
891
+
892
+ start_frame_event = start_frame.release(
893
+ fn=slider_changed,
894
+ inputs=[show_trim_preview_btn, video_input, start_frame],
895
+ outputs=[preview_image, preview_video],
896
+ show_progress=True,
897
+ )
898
+
899
+ end_frame_event = end_frame.release(
900
+ fn=slider_changed,
901
+ inputs=[show_trim_preview_btn, video_input, end_frame],
902
+ outputs=[preview_image, preview_video],
903
+ show_progress=True,
904
+ )
905
+
906
+ input_type.change(
907
+ update_radio,
908
+ inputs=[input_type],
909
+ outputs=[input_image_group, input_video_group, input_directory_group],
910
+ )
911
+ swap_option.change(
912
+ swap_option_changed,
913
+ inputs=[swap_option],
914
+ outputs=[age, specific_face, source_image_input],
915
+ )
916
+
917
+ apply_detection_settings.click(
918
+ analyse_settings_changed,
919
+ inputs=[detect_condition_dropdown, detection_size, detection_threshold],
920
+ outputs=[info],
921
+ )
922
+
923
+ src_specific_inputs = []
924
+ gen_variable_txt = ",".join(
925
+ [f"src{i+1}" for i in range(NUM_OF_SRC_SPECIFIC)]
926
+ + [f"trg{i+1}" for i in range(NUM_OF_SRC_SPECIFIC)]
927
+ )
928
+ exec(f"src_specific_inputs = ({gen_variable_txt})")
929
+ swap_inputs = [
930
+ input_type,
931
+ image_input,
932
+ video_input,
933
+ direc_input,
934
+ source_image_input,
935
+ output_directory,
936
+ output_name,
937
+ keep_output_sequence,
938
+ swap_option,
939
+ age,
940
+ distance_slider,
941
+ face_enhancer_name,
942
+ enable_face_parser_mask,
943
+ mask_include,
944
+ mask_soft_kernel,
945
+ mask_soft_iterations,
946
+ blur_amount,
947
+ erode_amount,
948
+ face_scale,
949
+ enable_laplacian_blend,
950
+ crop_top,
951
+ crop_bott,
952
+ crop_left,
953
+ crop_right,
954
+ *src_specific_inputs,
955
  ]
956
 
957
+ swap_outputs = [
958
+ info,
959
+ preview_image,
960
+ output_directory_button,
961
+ output_video_button,
962
+ preview_video,
 
 
 
963
  ]
964
 
965
+ swap_event = swap_button.click(
966
+ fn=process, inputs=swap_inputs, outputs=swap_outputs, show_progress=True
967
+ )
968
 
969
+
970
+ cancel_button.click(
971
+ fn=stop_running,
972
+ inputs=None,
973
+ outputs=[info],
974
+ cancels=[
975
+ swap_event,
976
+ trim_and_reload_event,
977
+ set_slider_range_event,
978
+ start_frame_event,
979
+ end_frame_event,
980
+ ],
981
+ show_progress=True,
982
+
983
+ )
984
+ output_directory_button.click(
985
+ lambda: open_directory(path=WORKSPACE), inputs=None, outputs=None
986
+ )
987
+ output_video_button.click(
988
+ lambda: open_directory(path=OUTPUT_FILE), inputs=None, outputs=None
989
+ )
990
+
991
+ # Save ボタンのコールバック関数
992
+ def save_to_huggingface(image):
993
+ import os
994
+ from datasets import Dataset, Features, Image as DatasetImage, Value, load_dataset, concatenate_datasets
995
+ import datetime
996
+ from PIL import Image # 追加
997
+
998
+ save_repo = os.environ.get("SAVE_REPO")
999
  hf_token = os.environ.get("HF_TOKEN")
1000
+
1001
+ if not save_repo or not hf_token:
1002
+ return "❌ 環境変数 SAVE_REPO または HF_TOKEN が設定されていません"
1003
 
1004
  try:
1005
+ # 既存のデータセットをロード
 
1006
  try:
1007
+ ds = load_dataset(save_repo, split='train', use_auth_token=hf_token)
1008
+ except:
1009
+ # データセットがまだ存在しない場合、新規作成
1010
+ ds = Dataset.from_dict({"image": [], "timestamp": []})
1011
+
1012
+ # 画像を保存し、新しいデータを追加
1013
+ image_pil = Image.fromarray(image.astype('uint8'), 'RGB') # numpy 配列を PIL 画像に変換
1014
+ image_path = "temp_output_image.png"
1015
+ image_pil.save(image_path)
1016
+ new_data = Dataset.from_dict({
1017
+ "image": [image_path],
1018
+ "timestamp": [str(datetime.datetime.now())]
1019
+ }, features=Features({
1020
+ "image": DatasetImage(),
1021
+ "timestamp": Value("string"),
1022
+ }))
1023
+
1024
+ # データセットを更新
1025
+ ds = concatenate_datasets([ds, new_data])
1026
+
1027
+ # データセットをプッシュ
1028
+ ds.push_to_hub(save_repo, token=hf_token)
1029
+
1030
+ # 一時ファイルを削除
1031
+ os.remove(image_path)
1032
+
1033
  return "✔️ 画像をHugging Faceデータセットに保存しました"
1034
  except Exception as e:
1035
+ return f"❌ エラーが発生しました: {str(e)}"
1036
+
1037
+ # Save ボタンのイベント設定
1038
+ save_button.click(
1039
+ fn=save_to_huggingface,
1040
+ inputs=[preview_image],
1041
+ outputs=[info],
1042
+ show_progress=True
1043
+ )
1044
+
1045
+ def load_target_image_with_choice(repo_choice): # 引数を追加
1046
+ img, message = load_latest_image_from_repo(repo_choice)
1047
+ if img is None:
1048
+ return None, f"### {message}"
1049
+ return img, f"### {message}"
1050
+
1051
+ target_load_button.click(
1052
+ fn=load_target_image_with_choice, # 関数名を変更
1053
+ inputs=[repo_choice_radio], # ラジオボタンの値を入力として渡す
1054
+ outputs=[image_input, target_load_info],
1055
+ show_progress=True,
1056
+ )
1057
+
1058
+ repo_choice_radio.change( # ラジオボタンの変更イベントを追加
1059
+ fn=load_target_image_with_choice,
1060
+ inputs=[repo_choice_radio],
1061
+ outputs=[image_input, target_load_info],
1062
+ show_progress=True,
1063
+ )
1064
 
 
 
1065
 
1066
if __name__ == "__main__":
    if USE_COLAB:
        print("Running in colab mode")

    # Queue caps concurrent jobs; a public share link is created only
    # when running under colab mode.
    interface.queue(concurrency_count=2, max_size=20).launch(share=USE_COLAB)