|
|
| import torch |
| import sys |
| import os |
| import numpy as np |
|
|
| |
# Expose modules in the current working directory to the imports that follow
# (appended last, so installed packages keep precedence).
sys.path.extend([os.getcwd()])
|
|
| from diffusers import DiffusionPipeline |
|
|
# Sampling settings shared by the image-to-image and text-to-image runs.
_SAMPLING_KWARGS = {
    "num_images_per_prompt": 5,
    "num_inference_steps": 250,
    "guidance_scale": 2.5,
    "guidance_high": 0.75,
    "guidance_low": 0.0,
    "mode": "sde",
    "path_type": "linear",
    "seed": 42,
}


def _resolve_hf_token():
    """Return a Hugging Face access token from ``HF_TOKEN``, or ``None``.

    ``None`` lets ``huggingface_hub`` apply its own token lookup (cached
    login / environment). An empty string would instead be treated as an
    explicitly supplied token and override that lookup.
    """
    return os.environ.get("HF_TOKEN") or None


def _load_pipeline(repo_id, hf_token, device):
    """Download *repo_id* and build the ``SiTPipeline`` shipped inside it.

    Returns ``(pipeline, snapshot_path)``; any exception propagates.
    """
    from huggingface_hub import snapshot_download

    print("Downloading repository snapshot...")
    snapshot_path = snapshot_download(
        repo_id=repo_id,
        repo_type="model",
        token=hf_token,
        allow_patterns=["*"],
    )
    print(f"Snapshot downloaded to: {snapshot_path}")

    # The pipeline class lives in a `pipeline` module inside the snapshot, so
    # the snapshot directory goes first on sys.path to make it importable.
    # SECURITY NOTE(review): this executes code downloaded from the Hub; only
    # point repo_id at repositories you trust.
    sys.path.insert(0, snapshot_path)
    from pipeline import SiTPipeline

    print("Initializing SiTPipeline from snapshot...")
    pipeline = SiTPipeline.from_pretrained(snapshot_path)
    pipeline.to(device)
    print("Pipeline loaded successfully from Hugging Face snapshot!")
    return pipeline, snapshot_path


def _resolve_test_image(snapshot_path, repo_id, hf_token):
    """Return a local path to ``test_image.png``, downloading it if needed."""
    test_img_path = os.path.join(snapshot_path, "test_image.png")
    if not os.path.exists(test_img_path):
        print(f"Test image not found at {test_img_path}")
        from huggingface_hub import hf_hub_download

        test_img_path = hf_hub_download(
            repo_id=repo_id, filename="test_image.png", token=hf_token
        )
    return test_img_path


def _generate_and_save(pipeline, label, file_prefix, **inputs):
    """Call *pipeline* with *inputs* plus the shared sampling settings.

    Saves each returned image as ``{file_prefix}_{i}.png`` in the current
    directory and prints a summary tagged with *label*.
    """
    output = pipeline(**inputs, **_SAMPLING_KWARGS)
    images = output["images"]
    for i, img in enumerate(images):
        img.save(f"{file_prefix}_{i}.png")
    print(f"Saved {len(images)} {label} images.")


def test_pipeline():
    """End-to-end smoke test for the ``xiangjx/MuPaD-512`` pipeline.

    Downloads the repository snapshot, loads its ``SiTPipeline``, then runs
    an image-to-image and a text-to-image generation, writing the results
    as PNG files in the current directory. Failures are reported (message
    plus traceback) instead of raised; a load failure ends the run early.

    NOTE(review): the ``test_`` prefix means pytest will collect and run
    this function if the file is picked up as a test module.
    """
    import traceback

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    hf_token = _resolve_hf_token()
    repo_id = "xiangjx/MuPaD-512"

    print(f"Loading pipeline from {repo_id}...")
    try:
        pipeline, snapshot_path = _load_pipeline(repo_id, hf_token, device)
    except Exception as e:
        print(f"Failed to load pipeline: {e}")
        traceback.print_exc()
        return

    print("Running Image-to-Image Generation:")
    try:
        from PIL import Image

        test_img_path = _resolve_test_image(snapshot_path, repo_id, hf_token)
        print(f"Using test image: {test_img_path}")
        # Close the source file; convert() loads pixel data into a new image.
        with Image.open(test_img_path) as src:
            raw_image = src.convert("RGB")
        _generate_and_save(
            pipeline, "Raw I2I", "raw_image2image",
            image=raw_image, modality="image",
        )
    except Exception as e:
        print(f"Raw I2I Failed: {e}")
        traceback.print_exc()

    print("Running Text-to-Image Generation:")
    try:
        _generate_and_save(
            pipeline, "Raw T2I", "raw_text2image",
            prompt="lung adenocarcinoma", modality="text",
        )
    except Exception as e:
        print(f"Raw T2I Failed: {e}")
        traceback.print_exc()
|
|
|
|
if __name__ == "__main__":
    # Run the smoke test only when executed as a script, not on import.
    test_pipeline()
|
|