Hpsoyl committed on
Commit
bc82a4c
·
1 Parent(s): f0616d4

more prompt

Browse files
Files changed (1) hide show
  1. app.py +58 -12
app.py CHANGED
@@ -149,29 +149,43 @@ def update_sr_prompt(model_name):
149
  return "F-actin of COS-7"
150
  return "" # 或者返回一个默认值
151
 
 
 
152
  def load_all_prompts():
153
- prompt_files = [
154
- r"prompts/basic_prompts.json",
155
- r"prompts/others_prompts.json",
156
- r"prompts/hpa_prompts.json"
 
 
 
 
 
 
 
 
 
 
157
  ]
158
 
159
- combined_prompts = []
160
- for file_path in prompt_files:
 
 
161
  try:
162
  if os.path.exists(file_path):
163
  with open(file_path, "r", encoding="utf-8") as f:
164
  data = json.load(f)
165
  if isinstance(data, list):
166
  combined_prompts.extend(data)
167
- print(f"✓ Loaded {len(data)} prompts from: {os.path.basename(file_path)}")
168
- else:
169
- print(f" Warning: File not found: {file_path}")
170
  except Exception as e:
171
  print(f"✗ Error loading {file_path}: {e}")
172
 
173
  if not combined_prompts:
174
- return ["F-actin of COS-7", "ER of COS-7", "Mitochondria of BPAE"]
175
  return combined_prompts
176
  T2I_PROMPTS = load_all_prompts()
177
 
@@ -186,6 +200,7 @@ try:
186
  t2i_tokenizer = CLIPTokenizer.from_pretrained(T2I_PRETRAINED_MODEL_PATH, subfolder="tokenizer")
187
  t2i_pipe = DDPMPipeline(unet=t2i_unet, scheduler=t2i_noise_scheduler, text_encoder=t2i_text_encoder, tokenizer=t2i_tokenizer)
188
  t2i_pipe.to(DEVICE)
 
189
  print("✓ Text-to-Image model loaded successfully!")
190
  except Exception as e:
191
  print(f"!!!!!! FATAL: Text-to-Image Model Loading Failed !!!!!!\nError: {e}")
@@ -217,9 +232,40 @@ def swap_controlnet(pipe, target_path):
217
  raise gr.Error(f"Failed to load ControlNet model '{target_path}'. Error: {e}")
218
  return pipe
219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  def generate_t2i(prompt, num_inference_steps):
 
221
  if t2i_pipe is None: raise gr.Error("Text-to-Image model is not loaded.")
222
- print(f"\nTask started... | Prompt: '{prompt}' | Steps: {num_inference_steps}")
 
 
 
 
 
223
  image_np = t2i_pipe(prompt.lower(), generator=None, num_inference_steps=int(num_inference_steps), output_type="np").images
224
  generated_image = numpy_to_pil(image_np)
225
  print("✓ Image generated")
@@ -577,7 +623,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
577
  # sr_prompt_input = gr.Textbox(label="Prompt (e.g., structure name)", value="CCPs of COS-7")
578
  sr_prompt_input = gr.Textbox(
579
  label="Prompt",
580
- value="CCPs of COS-7", # 初始值根据你的默认选择设定
581
  interactive=False
582
  )
583
  sr_steps_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Inference Steps")
 
149
  return "F-actin of COS-7"
150
  return "" # 或者返回一个默认值
151
 
152
# Maps each prompt string to the UNet checkpoint it should be generated with.
PROMPT_TO_MODEL_MAP = {}
# Path of the UNet currently loaded into t2i_pipe (None until first load).
current_t2i_unet_path = None

def load_all_prompts():
    """Load all T2I prompts from the per-category JSON files.

    Side effect: fills the module-level PROMPT_TO_MODEL_MAP so each prompt
    can later be routed to the checkpoint it belongs to (see generate_t2i).

    Returns:
        list[str]: every prompt found across the category files, or a small
        hard-coded fallback list when no file could be read.
    """
    global PROMPT_TO_MODEL_MAP
    # Each prompt file is paired with the UNet checkpoint trained for it.
    categories = [
        {
            "file": "prompts/basic_prompts.json",
            "model": f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000",
        },
        {
            "file": "prompts/others_prompts.json",
            "model": f"{MODELS_ROOT_DIR}/FluoGen-demo-test-ckpts/FULL-checkpoint-275000",
        },
        {
            "file": "prompts/hpa_prompts.json",
            "model": f"{MODELS_ROOT_DIR}/FluoGen-demo-test-ckpts/HPA-checkpoint-40000",
        },
    ]

    combined_prompts = []
    for cat in categories:
        file_path = cat["file"]
        model_path = cat["model"]
        try:
            if os.path.exists(file_path):
                with open(file_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
                # Only list-shaped files are valid prompt collections.
                if isinstance(data, list):
                    combined_prompts.extend(data)
                    for p in data:
                        PROMPT_TO_MODEL_MAP[p] = model_path
                    print(f"✓ Loaded {len(data)} prompts from {file_path}")
            else:
                # Fix: the previous revision reported missing files; the
                # rewrite dropped that branch and skipped them silently.
                print(f"⚠ Warning: File not found: {file_path}")
        except Exception as e:
            # Best-effort loading: a bad file must not take the app down.
            print(f"✗ Error loading {file_path}: {e}")

    if not combined_prompts:
        # Fallback so the UI still offers choices; note these prompts will
        # not be present in PROMPT_TO_MODEL_MAP.
        return ["F-actin of COS-7", "ER of COS-7"]
    return combined_prompts

T2I_PROMPTS = load_all_prompts()
191
 
 
200
  t2i_tokenizer = CLIPTokenizer.from_pretrained(T2I_PRETRAINED_MODEL_PATH, subfolder="tokenizer")
201
  t2i_pipe = DDPMPipeline(unet=t2i_unet, scheduler=t2i_noise_scheduler, text_encoder=t2i_text_encoder, tokenizer=t2i_tokenizer)
202
  t2i_pipe.to(DEVICE)
203
+ current_t2i_unet_path = T2I_UNET_PATH
204
  print("✓ Text-to-Image model loaded successfully!")
205
  except Exception as e:
206
  print(f"!!!!!! FATAL: Text-to-Image Model Loading Failed !!!!!!\nError: {e}")
 
232
  raise gr.Error(f"Failed to load ControlNet model '{target_path}'. Error: {e}")
233
  return pipe
234
 
235
def swap_t2i_unet(pipe, target_unet_path):
    """Ensure `pipe` carries the UNet weights at `target_unet_path`.

    Reloads the UNet only when the requested path differs from the one
    currently loaded (tracked via the module-level current_t2i_unet_path).
    Returns the (possibly updated) pipeline; raises gr.Error on load failure.
    """
    global current_t2i_unet_path
    target_unet_path = os.path.normpath(target_unet_path)

    # Guard clause: nothing to do when the requested weights are already live.
    already_loaded = (
        current_t2i_unet_path is not None
        and os.path.normpath(current_t2i_unet_path) == target_unet_path
    )
    if already_loaded:
        return pipe

    print(f"🔄 Swapping T2I UNet to: {target_unet_path}")
    try:
        pipe.unet = UNet2DModel.from_pretrained(
            target_unet_path, subfolder="unet"
        ).to(DEVICE)
    except Exception as e:
        # Surface the failure in the Gradio UI instead of crashing the worker.
        raise gr.Error(f"Failed to load UNet from {target_unet_path}. Error: {e}")
    current_t2i_unet_path = target_unet_path
    print("✅ UNet swapped successfully.")
    return pipe
248
+
249
+ # def generate_t2i(prompt, num_inference_steps):
250
+ # if t2i_pipe is None: raise gr.Error("Text-to-Image model is not loaded.")
251
+ # print(f"\nTask started... | Prompt: '{prompt}' | Steps: {num_inference_steps}")
252
+ # image_np = t2i_pipe(prompt.lower(), generator=None, num_inference_steps=int(num_inference_steps), output_type="np").images
253
+ # generated_image = numpy_to_pil(image_np)
254
+ # print("✓ Image generated")
255
+ # if SAVE_EXAMPLES:
256
+ # example_filepath = os.path.join(T2I_EXAMPLE_IMG_DIR, sanitize_prompt_for_filename(prompt))
257
+ # if not os.path.exists(example_filepath):
258
+ # generated_image.save(example_filepath); print(f"✓ New T2I example saved: {example_filepath}")
259
+ # return generated_image
260
  def generate_t2i(prompt, num_inference_steps):
261
+ global t2i_pipe
262
  if t2i_pipe is None: raise gr.Error("Text-to-Image model is not loaded.")
263
+ target_model_path = PROMPT_TO_MODEL_MAP.get(prompt)
264
+ if target_model_path:
265
+ t2i_pipe = swap_t2i_unet(t2i_pipe, target_model_path)
266
+ else:
267
+ print(f"⚠️ Warning: No specific model mapped for '{prompt}', using current weights.")
268
+ print(f"\n🚀 Task started... | Prompt: '{prompt}' | Model: {current_t2i_unet_path}")
269
  image_np = t2i_pipe(prompt.lower(), generator=None, num_inference_steps=int(num_inference_steps), output_type="np").images
270
  generated_image = numpy_to_pil(image_np)
271
  print("✓ Image generated")
 
623
  # sr_prompt_input = gr.Textbox(label="Prompt (e.g., structure name)", value="CCPs of COS-7")
624
  sr_prompt_input = gr.Textbox(
625
  label="Prompt",
626
+ value="F-actin of COS-7", # 初始值根据你的默认选择设定
627
  interactive=False
628
  )
629
  sr_steps_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Inference Steps")