Spaces:
Sleeping
Sleeping
more prompt
Browse files
app.py
CHANGED
|
@@ -149,29 +149,43 @@ def update_sr_prompt(model_name):
|
|
| 149 |
return "F-actin of COS-7"
|
| 150 |
return "" # 或者返回一个默认值
|
| 151 |
|
|
|
|
|
|
|
| 152 |
def load_all_prompts():
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
]
|
| 158 |
|
| 159 |
-
combined_prompts = []
|
| 160 |
-
for
|
|
|
|
|
|
|
| 161 |
try:
|
| 162 |
if os.path.exists(file_path):
|
| 163 |
with open(file_path, "r", encoding="utf-8") as f:
|
| 164 |
data = json.load(f)
|
| 165 |
if isinstance(data, list):
|
| 166 |
combined_prompts.extend(data)
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
except Exception as e:
|
| 171 |
print(f"✗ Error loading {file_path}: {e}")
|
| 172 |
|
| 173 |
if not combined_prompts:
|
| 174 |
-
return ["F-actin of COS-7", "ER of COS-7"
|
| 175 |
return combined_prompts
|
| 176 |
T2I_PROMPTS = load_all_prompts()
|
| 177 |
|
|
@@ -186,6 +200,7 @@ try:
|
|
| 186 |
t2i_tokenizer = CLIPTokenizer.from_pretrained(T2I_PRETRAINED_MODEL_PATH, subfolder="tokenizer")
|
| 187 |
t2i_pipe = DDPMPipeline(unet=t2i_unet, scheduler=t2i_noise_scheduler, text_encoder=t2i_text_encoder, tokenizer=t2i_tokenizer)
|
| 188 |
t2i_pipe.to(DEVICE)
|
|
|
|
| 189 |
print("✓ Text-to-Image model loaded successfully!")
|
| 190 |
except Exception as e:
|
| 191 |
print(f"!!!!!! FATAL: Text-to-Image Model Loading Failed !!!!!!\nError: {e}")
|
|
@@ -217,9 +232,40 @@ def swap_controlnet(pipe, target_path):
|
|
| 217 |
raise gr.Error(f"Failed to load ControlNet model '{target_path}'. Error: {e}")
|
| 218 |
return pipe
|
| 219 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
def generate_t2i(prompt, num_inference_steps):
|
|
|
|
| 221 |
if t2i_pipe is None: raise gr.Error("Text-to-Image model is not loaded.")
|
| 222 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
image_np = t2i_pipe(prompt.lower(), generator=None, num_inference_steps=int(num_inference_steps), output_type="np").images
|
| 224 |
generated_image = numpy_to_pil(image_np)
|
| 225 |
print("✓ Image generated")
|
|
@@ -577,7 +623,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
| 577 |
# sr_prompt_input = gr.Textbox(label="Prompt (e.g., structure name)", value="CCPs of COS-7")
|
| 578 |
sr_prompt_input = gr.Textbox(
|
| 579 |
label="Prompt",
|
| 580 |
-
value="
|
| 581 |
interactive=False
|
| 582 |
)
|
| 583 |
sr_steps_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Inference Steps")
|
|
|
|
| 149 |
return "F-actin of COS-7"
|
| 150 |
return "" # 或者返回一个默认值
|
| 151 |
|
| 152 |
+
PROMPT_TO_MODEL_MAP = {}
|
| 153 |
+
current_t2i_unet_path = None
|
| 154 |
def load_all_prompts():
|
| 155 |
+
global PROMPT_TO_MODEL_MAP
|
| 156 |
+
categories = [
|
| 157 |
+
{
|
| 158 |
+
"file": "prompts/basic_prompts.json",
|
| 159 |
+
"model": f"{MODELS_ROOT_DIR}/UNET_T2I_CONTROLNET/checkpoint-285000"
|
| 160 |
+
},
|
| 161 |
+
{
|
| 162 |
+
"file": "prompts/others_prompts.json",
|
| 163 |
+
"model": f"{MODELS_ROOT_DIR}/FluoGen-demo-test-ckpts/FULL-checkpoint-275000"
|
| 164 |
+
},
|
| 165 |
+
{
|
| 166 |
+
"file": "prompts/hpa_prompts.json",
|
| 167 |
+
"model": f"{MODELS_ROOT_DIR}/FluoGen-demo-test-ckpts/HPA-checkpoint-40000"
|
| 168 |
+
}
|
| 169 |
]
|
| 170 |
|
| 171 |
+
combined_prompts = []
|
| 172 |
+
for cat in categories:
|
| 173 |
+
file_path = cat["file"]
|
| 174 |
+
model_path = cat["model"]
|
| 175 |
try:
|
| 176 |
if os.path.exists(file_path):
|
| 177 |
with open(file_path, "r", encoding="utf-8") as f:
|
| 178 |
data = json.load(f)
|
| 179 |
if isinstance(data, list):
|
| 180 |
combined_prompts.extend(data)
|
| 181 |
+
for p in data:
|
| 182 |
+
PROMPT_TO_MODEL_MAP[p] = model_path
|
| 183 |
+
print(f"✓ Loaded {len(data)} prompts from {file_path}")
|
| 184 |
except Exception as e:
|
| 185 |
print(f"✗ Error loading {file_path}: {e}")
|
| 186 |
|
| 187 |
if not combined_prompts:
|
| 188 |
+
return ["F-actin of COS-7", "ER of COS-7"]
|
| 189 |
return combined_prompts
|
| 190 |
T2I_PROMPTS = load_all_prompts()
|
| 191 |
|
|
|
|
| 200 |
t2i_tokenizer = CLIPTokenizer.from_pretrained(T2I_PRETRAINED_MODEL_PATH, subfolder="tokenizer")
|
| 201 |
t2i_pipe = DDPMPipeline(unet=t2i_unet, scheduler=t2i_noise_scheduler, text_encoder=t2i_text_encoder, tokenizer=t2i_tokenizer)
|
| 202 |
t2i_pipe.to(DEVICE)
|
| 203 |
+
current_t2i_unet_path = T2I_UNET_PATH
|
| 204 |
print("✓ Text-to-Image model loaded successfully!")
|
| 205 |
except Exception as e:
|
| 206 |
print(f"!!!!!! FATAL: Text-to-Image Model Loading Failed !!!!!!\nError: {e}")
|
|
|
|
| 232 |
raise gr.Error(f"Failed to load ControlNet model '{target_path}'. Error: {e}")
|
| 233 |
return pipe
|
| 234 |
|
| 235 |
+
def swap_t2i_unet(pipe, target_unet_path):
|
| 236 |
+
global current_t2i_unet_path
|
| 237 |
+
target_unet_path = os.path.normpath(target_unet_path)
|
| 238 |
+
if current_t2i_unet_path is None or os.path.normpath(current_t2i_unet_path) != target_unet_path:
|
| 239 |
+
print(f"🔄 Swapping T2I UNet to: {target_unet_path}")
|
| 240 |
+
try:
|
| 241 |
+
new_unet = UNet2DModel.from_pretrained(target_unet_path, subfolder="unet").to(DEVICE)
|
| 242 |
+
pipe.unet = new_unet
|
| 243 |
+
current_t2i_unet_path = target_unet_path
|
| 244 |
+
print("✅ UNet swapped successfully.")
|
| 245 |
+
except Exception as e:
|
| 246 |
+
raise gr.Error(f"Failed to load UNet from {target_unet_path}. Error: {e}")
|
| 247 |
+
return pipe
|
| 248 |
+
|
| 249 |
+
# def generate_t2i(prompt, num_inference_steps):
|
| 250 |
+
# if t2i_pipe is None: raise gr.Error("Text-to-Image model is not loaded.")
|
| 251 |
+
# print(f"\nTask started... | Prompt: '{prompt}' | Steps: {num_inference_steps}")
|
| 252 |
+
# image_np = t2i_pipe(prompt.lower(), generator=None, num_inference_steps=int(num_inference_steps), output_type="np").images
|
| 253 |
+
# generated_image = numpy_to_pil(image_np)
|
| 254 |
+
# print("✓ Image generated")
|
| 255 |
+
# if SAVE_EXAMPLES:
|
| 256 |
+
# example_filepath = os.path.join(T2I_EXAMPLE_IMG_DIR, sanitize_prompt_for_filename(prompt))0
|
| 257 |
+
# if not os.path.exists(example_filepath):
|
| 258 |
+
# generated_image.save(example_filepath); print(f"✓ New T2I example saved: {example_filepath}")
|
| 259 |
+
# return generated_image
|
| 260 |
def generate_t2i(prompt, num_inference_steps):
|
| 261 |
+
global t2i_pipe
|
| 262 |
if t2i_pipe is None: raise gr.Error("Text-to-Image model is not loaded.")
|
| 263 |
+
target_model_path = PROMPT_TO_MODEL_MAP.get(prompt)
|
| 264 |
+
if target_model_path:
|
| 265 |
+
t2i_pipe = swap_t2i_unet(t2i_pipe, target_model_path)
|
| 266 |
+
else:
|
| 267 |
+
print(f"⚠️ Warning: No specific model mapped for '{prompt}', using current weights.")
|
| 268 |
+
print(f"\n🚀 Task started... | Prompt: '{prompt}' | Model: {current_t2i_unet_path}")
|
| 269 |
image_np = t2i_pipe(prompt.lower(), generator=None, num_inference_steps=int(num_inference_steps), output_type="np").images
|
| 270 |
generated_image = numpy_to_pil(image_np)
|
| 271 |
print("✓ Image generated")
|
|
|
|
| 623 |
# sr_prompt_input = gr.Textbox(label="Prompt (e.g., structure name)", value="CCPs of COS-7")
|
| 624 |
sr_prompt_input = gr.Textbox(
|
| 625 |
label="Prompt",
|
| 626 |
+
value="F-actin of COS-7", # 初始值根据你的默认选择设定
|
| 627 |
interactive=False
|
| 628 |
)
|
| 629 |
sr_steps_slider = gr.Slider(minimum=5, maximum=50, step=1, value=10, label="Inference Steps")
|