rayquaza384mega commited on
Commit
4ab5cd7
·
1 Parent(s): 9778ac9

Pseudo-color-6

Browse files
Files changed (1) hide show
  1. app.py +57 -17
app.py CHANGED
@@ -49,6 +49,13 @@ WEIGHT_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
49
  LOGO_PATH = "utils/logo2_transparent.png"
50
  SAVE_EXAMPLES = False
51
 
 
 
 
 
 
 
 
52
  # --- Base directory for all models ---
53
  REPO_ID = "FluoGen-Group/FluoGen-demo-test-ckpts"
54
  MODELS_ROOT_DIR = snapshot_download(repo_id=REPO_ID, token=hf_token)
@@ -303,24 +310,52 @@ def get_gallery_selection(evt: gr.SelectData):
303
 
304
  # --- Generation Functions ---
305
  @spaces.GPU(duration=120)
306
- def generate_t2i(prompt, num_inference_steps, current_color):
 
 
 
307
  global t2i_pipe
308
  if t2i_pipe is None: raise gr.Error("Text-to-Image model is not loaded.")
309
  target_model_path = PROMPT_TO_MODEL_MAP.get(prompt, f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/FULL-checkpoint-275000")
310
  t2i_pipe = swap_t2i_unet(t2i_pipe, target_model_path)
311
 
312
- print(f"\n🚀 T2I Task started... | Prompt: '{prompt}'")
313
- image_np = t2i_pipe(prompt.lower(), generator=None, num_inference_steps=int(num_inference_steps), output_type="np").images
314
 
315
- raw_file_path = save_temp_tiff(image_np, prefix="t2i_raw")
316
- display_image = apply_pseudocolor(image_np, current_color)
317
- colorbar_img = generate_colorbar_preview(current_color)
 
318
 
319
- if SAVE_EXAMPLES:
320
- example_filepath = os.path.join(T2I_EXAMPLE_IMG_DIR, sanitize_prompt_for_filename(prompt))
321
- if not os.path.exists(example_filepath): display_image.save(example_filepath)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
 
323
- return display_image, raw_file_path, image_np, colorbar_img
 
324
 
325
  @spaces.GPU(duration=120)
326
  def run_mask_to_image_generation(mask_file_obj, cell_type, num_images, steps, seed, current_color):
@@ -491,7 +526,7 @@ seg_examples = load_examples(SEG_EXAMPLE_IMG_DIR)
491
  cls_examples = load_examples(CLS_EXAMPLE_IMG_DIR)
492
 
493
  # --- UI Builders ---
494
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
495
  with gr.Row():
496
  gr.Image(value=LOGO_PATH, width=300, height=200, container=False, interactive=False, show_download_button=False, show_fullscreen_button=False)
497
  gr.Markdown(f"# {MODEL_TITLE}\n{MODEL_DESCRIPTION}")
@@ -499,25 +534,29 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
499
  with gr.Tabs():
500
  # --- TAB 1: Text-to-Image ---
501
  with gr.Tab("Text-to-Image Generation", id="txt2img"):
502
- t2i_raw_state = gr.State(None)
503
  with gr.Row(variant="panel"):
504
  with gr.Column(scale=1, min_width=350):
505
  t2i_prompt = gr.Dropdown(choices=T2I_PROMPTS, value=T2I_PROMPTS[0], label="Search or Type a Prompt", filterable=True, allow_custom_value=True)
506
  t2i_steps = gr.Slider(10, 200, 50, step=1, label="Inference Steps")
 
 
507
  t2i_btn = gr.Button("Generate", variant="primary")
508
  with gr.Column(scale=2):
509
- t2i_out = gr.Image(label="Generated Image", type="pil", interactive=False, show_download_button=False)
 
510
  with gr.Row(equal_height=True):
511
  with gr.Column(scale=2):
512
  t2i_color = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor (Adjust after generation)")
513
  with gr.Column(scale=2):
514
  t2i_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False)
515
  with gr.Column(scale=1):
516
- t2i_dl = gr.DownloadButton(label="Download Raw (.tif)")
517
  t2i_gal = gr.Gallery(value=t2i_examples, label="Examples", columns=6, height="auto")
518
 
519
- t2i_btn.click(generate_t2i, [t2i_prompt, t2i_steps, t2i_color], [t2i_out, t2i_dl, t2i_raw_state, t2i_colorbar])
520
- t2i_color.change(update_single_image_color, [t2i_raw_state, t2i_color], [t2i_out, t2i_colorbar])
 
521
  t2i_gal.select(fn=get_gallery_selection, inputs=None, outputs=t2i_prompt)
522
 
523
  # --- TAB 2: Super-Resolution ---
@@ -594,7 +633,8 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
594
  with gr.Row(variant="panel"):
595
  with gr.Column(scale=1, min_width=350):
596
  m2i_file = gr.File(label="Upload Mask (.tif)", file_types=['.tif', '.tiff'])
597
- m2i_type = gr.Textbox(label="Cell Type", placeholder="e.g., HeLa")
 
598
  m2i_num = gr.Slider(1, 10, 5, step=1, label="Count")
599
  m2i_steps = gr.Slider(5, 50, 10, step=1, label="Steps")
600
  m2i_seed = gr.Number(label="Seed", value=42)
 
49
  LOGO_PATH = "utils/logo2_transparent.png"
50
  SAVE_EXAMPLES = False
51
 
52
+ # --- CSS for Times New Roman ---
53
+ CUSTOM_CSS = """
54
+ .gradio-container, .gradio-container * {
55
+ font-family: 'Times New Roman', Times, serif !important;
56
+ }
57
+ """
58
+
59
  # --- Base directory for all models ---
60
  REPO_ID = "FluoGen-Group/FluoGen-demo-test-ckpts"
61
  MODELS_ROOT_DIR = snapshot_download(repo_id=REPO_ID, token=hf_token)
 
310
 
311
  # --- Generation Functions ---
312
  @spaces.GPU(duration=120)
313
+ def generate_t2i(prompt, num_inference_steps, num_images, current_color):
314
+ """
315
+ Generates multiple images for Text-to-Image and returns a gallery.
316
+ """
317
  global t2i_pipe
318
  if t2i_pipe is None: raise gr.Error("Text-to-Image model is not loaded.")
319
  target_model_path = PROMPT_TO_MODEL_MAP.get(prompt, f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/FULL-checkpoint-275000")
320
  t2i_pipe = swap_t2i_unet(t2i_pipe, target_model_path)
321
 
322
+ print(f"\n🚀 T2I Task started... | Prompt: '{prompt}' | Count: {num_images}")
 
323
 
324
+ generated_raw_list = []
325
+ generated_display_images = []
326
+ generated_raw_files = []
327
+ temp_dir = tempfile.mkdtemp()
328
 
329
+ # Generate Batch
330
+ for i in range(int(num_images)):
331
+ # Generate single image
332
+ image_np = t2i_pipe(prompt.lower(), generator=None, num_inference_steps=int(num_inference_steps), output_type="np").images
333
+ generated_raw_list.append(image_np)
334
+
335
+ # Save raw to temp
336
+ raw_name = f"t2i_sample_{i+1}.tif"
337
+ raw_path = os.path.join(temp_dir, raw_name)
338
+ save_data = image_np.astype(np.float32) if image_np.dtype == np.float16 else image_np
339
+ tifffile.imwrite(raw_path, save_data)
340
+ generated_raw_files.append(raw_path)
341
+
342
+ # Create display version
343
+ generated_display_images.append(apply_pseudocolor(image_np, current_color))
344
+
345
+ # Save first image to examples if needed
346
+ if SAVE_EXAMPLES and i == 0:
347
+ example_filepath = os.path.join(T2I_EXAMPLE_IMG_DIR, sanitize_prompt_for_filename(prompt))
348
+ if not os.path.exists(example_filepath): generated_display_images[0].save(example_filepath)
349
+
350
+ # Zip raw files
351
+ zip_filename = os.path.join(temp_dir, "raw_output_images.zip")
352
+ with zipfile.ZipFile(zip_filename, 'w') as zipf:
353
+ for file in generated_raw_files: zipf.write(file, os.path.basename(file))
354
+
355
+ colorbar_img = generate_colorbar_preview(current_color)
356
 
357
+ # Return: Gallery List, Zip Path, Raw State List, Colorbar
358
+ return generated_display_images, zip_filename, generated_raw_list, colorbar_img
359
 
360
  @spaces.GPU(duration=120)
361
  def run_mask_to_image_generation(mask_file_obj, cell_type, num_images, steps, seed, current_color):
 
526
  cls_examples = load_examples(CLS_EXAMPLE_IMG_DIR)
527
 
528
  # --- UI Builders ---
529
+ with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
530
  with gr.Row():
531
  gr.Image(value=LOGO_PATH, width=300, height=200, container=False, interactive=False, show_download_button=False, show_fullscreen_button=False)
532
  gr.Markdown(f"# {MODEL_TITLE}\n{MODEL_DESCRIPTION}")
 
534
  with gr.Tabs():
535
  # --- TAB 1: Text-to-Image ---
536
  with gr.Tab("Text-to-Image Generation", id="txt2img"):
537
+ t2i_raw_state = gr.State(None) # Stores list of arrays
538
  with gr.Row(variant="panel"):
539
  with gr.Column(scale=1, min_width=350):
540
  t2i_prompt = gr.Dropdown(choices=T2I_PROMPTS, value=T2I_PROMPTS[0], label="Search or Type a Prompt", filterable=True, allow_custom_value=True)
541
  t2i_steps = gr.Slider(10, 200, 50, step=1, label="Inference Steps")
542
+ # Added: Number of Images Slider
543
+ t2i_num_images = gr.Slider(1, 9, 4, step=1, label="Number of Images")
544
  t2i_btn = gr.Button("Generate", variant="primary")
545
  with gr.Column(scale=2):
546
+ # Changed: Image to Gallery
547
+ t2i_gallery_out = gr.Gallery(label="Generated Images", columns=3, height="auto")
548
  with gr.Row(equal_height=True):
549
  with gr.Column(scale=2):
550
  t2i_color = gr.Dropdown(choices=COLOR_MAPS, value="Grayscale", label="Pseudocolor (Adjust after generation)")
551
  with gr.Column(scale=2):
552
  t2i_colorbar = gr.Image(label="Colorbar", show_label=False, container=False, height=40, show_download_button=False, interactive=False)
553
  with gr.Column(scale=1):
554
+ t2i_dl = gr.DownloadButton(label="Download All (.zip)")
555
  t2i_gal = gr.Gallery(value=t2i_examples, label="Examples", columns=6, height="auto")
556
 
557
+ t2i_btn.click(generate_t2i, [t2i_prompt, t2i_steps, t2i_num_images, t2i_color], [t2i_gallery_out, t2i_dl, t2i_raw_state, t2i_colorbar])
558
+ # Reuse update_gallery_color since state is now a list
559
+ t2i_color.change(update_gallery_color, [t2i_raw_state, t2i_color], [t2i_gallery_out, t2i_colorbar])
560
  t2i_gal.select(fn=get_gallery_selection, inputs=None, outputs=t2i_prompt)
561
 
562
  # --- TAB 2: Super-Resolution ---
 
633
  with gr.Row(variant="panel"):
634
  with gr.Column(scale=1, min_width=350):
635
  m2i_file = gr.File(label="Upload Mask (.tif)", file_types=['.tif', '.tiff'])
636
+ # Changed: Default value to HeLa
637
+ m2i_type = gr.Textbox(label="Cell Type", value="HeLa", placeholder="e.g., HeLa")
638
  m2i_num = gr.Slider(1, 10, 5, step=1, label="Count")
639
  m2i_steps = gr.Slider(5, 50, 10, step=1, label="Steps")
640
  m2i_seed = gr.Number(label="Seed", value=42)