# MuPaD-512 / demo.py
# xiangjx's picture
# Update demo.py
# ad9261c verified
import torch
import sys
import os
import numpy as np
# Append the current working directory to sys.path so that modules located
# there can be imported by the statements below.
sys.path.append(os.getcwd())
from diffusers import DiffusionPipeline
def _load_pipeline(repo_id, hf_token, device):
    """Download the model repository and build the pipeline on ``device``.

    Args:
        repo_id: Hugging Face repository id to snapshot.
        hf_token: Access token string, or ``None`` when no explicit token
            is supplied.
        device: Target handed to the pipeline's ``to`` method.

    Returns:
        A ``(pipeline, snapshot_path)`` tuple, where ``snapshot_path`` is the
        local directory returned by ``snapshot_download``.

    Raises:
        Exception: Anything raised by the download, the import of the
            repository's ``pipeline`` module, or model loading is propagated
            unchanged so the caller can report it.
    """
    # Imported lazily so a missing dependency is reported by the caller's
    # error handling instead of aborting at module import time.
    from huggingface_hub import snapshot_download

    print("Downloading repository snapshot...")
    snapshot_path = snapshot_download(
        repo_id=repo_id,
        repo_type="model",
        token=hf_token,
        allow_patterns=["*"],  # Download everything including code and weights
    )
    print(f"Snapshot downloaded to: {snapshot_path}")

    # Add snapshot path to sys.path to allow importing pipeline and musk.
    # Guarded so repeated calls do not pile up duplicate entries.
    if snapshot_path not in sys.path:
        sys.path.insert(0, snapshot_path)

    # Import the class directly from the snapshot. This mirrors
    # 'trust_remote_code' behavior but with explicit path control.
    from pipeline import SiTPipeline

    print("Initializing SiTPipeline from snapshot...")
    pipeline = SiTPipeline.from_pretrained(
        snapshot_path,
    )
    pipeline.to(device)
    print("Pipeline loaded successfully from Hugging Face snapshot!")
    return pipeline, snapshot_path


def _load_test_image(snapshot_path, repo_id, hf_token):
    """Return the repository's ``test_image.png`` as an RGB PIL image.

    The image is looked up inside the snapshot directory first; if it is not
    there it is fetched individually with ``hf_hub_download``.
    """
    from PIL import Image

    test_img_path = os.path.join(snapshot_path, "test_image.png")
    if not os.path.exists(test_img_path):
        print(f"Test image not found at {test_img_path}")
        # Try downloading specifically if missing (though snapshot should have it)
        from huggingface_hub import hf_hub_download
        test_img_path = hf_hub_download(
            repo_id=repo_id, filename="test_image.png", token=hf_token
        )
    print(f"Using test image: {test_img_path}")

    # convert() yields a separate image object, so the file handle can be
    # released as soon as the conversion is done.
    with Image.open(test_img_path) as source_image:
        return source_image.convert("RGB")


def _generate_and_save(label, file_prefix, generate):
    """Run one generation stage and save its images to the working directory.

    Args:
        label: Human-readable stage name used in the status messages.
        file_prefix: Prefix of the output files, written as
            ``{file_prefix}_{index}.png``.
        generate: Zero-argument callable that runs the pipeline and returns
            a mapping with an ``"images"`` entry.

    Any exception raised by the stage is reported together with its traceback
    and then swallowed, so one failing stage does not stop the next one.
    """
    import traceback

    try:
        images = generate()["images"]
        for index, image in enumerate(images):
            image.save(f"{file_prefix}_{index}.png")
        print(f"Saved {len(images)} {label} images.")
    except Exception as e:
        print(f"{label} Failed: {e}")
        traceback.print_exc()


def test_pipeline():
    """Run the MuPaD-512 demo: load the pipeline, then sample I2I and T2I.

    The pipeline is loaded from a snapshot of the ``xiangjx/MuPaD-512``
    repository. If loading fails the error is printed and the function
    returns early. Otherwise an image-conditioned stage and a text-conditioned
    stage are run one after the other, each writing its results as PNG files
    into the current working directory.

    The Hugging Face access token is read from the ``HF_TOKEN`` environment
    variable instead of being hard-coded in the source.

    Returns:
        None.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # An empty string is not a usable credential, so an unset or empty
    # variable is normalised to None rather than being forwarded as a token.
    hf_token = os.environ.get("HF_TOKEN") or None
    repo_id = "xiangjx/MuPaD-512"

    # 1. Load Diffusion Pipeline
    print(f"Loading pipeline from {repo_id}...")
    try:
        pipeline, snapshot_path = _load_pipeline(repo_id, hf_token, device)
    except Exception as e:
        print(f"Failed to load pipeline: {e}")
        return

    # Sampling settings shared by both stages.
    sampling_kwargs = dict(
        num_images_per_prompt=5,
        num_inference_steps=250,
        guidance_scale=2.5,
        guidance_high=0.75,
        guidance_low=0.0,
        mode="sde",
        path_type="linear",
        seed=42,
    )

    # I2I with Raw Image
    print("Running Image-to-Image Generation:")
    _generate_and_save(
        "Raw I2I",
        "raw_image2image",
        lambda: pipeline(
            # Pass PIL image directly. It is loaded inside the stage so that a
            # missing or unreadable image is reported as a stage failure.
            image=_load_test_image(snapshot_path, repo_id, hf_token),
            modality="image",
            **sampling_kwargs,
        ),
    )

    # T2I with Raw Text
    print("Running Text-to-Image Generation:")
    prompt = "lung adenocarcinoma"
    _generate_and_save(
        "Raw T2I",
        "raw_text2image",
        lambda: pipeline(
            prompt=prompt,
            modality="text",
            **sampling_kwargs,
        ),
    )
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    test_pipeline()