File size: 2,551 Bytes
b171568 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import os
import re
import torch
import torch.distributed as dist
from pathlib import Path
from diffusers import FluxPipeline
from diffusers import FluxTransformer2DModel
from torch.utils.data import Dataset, DistributedSampler
class PromptDataset(Dataset):
    """Map-style dataset yielding one text prompt per index.

    Loads the whole prompt file eagerly; blank/whitespace-only lines are
    dropped and surrounding whitespace is stripped from each prompt.
    """

    def __init__(self, file_path):
        with open(file_path, 'r') as f:
            lines = (raw.strip() for raw in f)
            self.prompts = [text for text in lines if text]

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return self.prompts[idx]
def sanitize_filename(text, max_length=200):
    """Turn *text* into a string safe to use as a filename.

    Each character that is invalid in Windows/Unix filenames
    (backslash, slash, colon, star, question mark, quote, angle
    brackets, pipe) becomes an underscore; the result is truncated to
    *max_length* characters and trailing whitespace is removed.
    Returns "untitled" when nothing is left.
    """
    cleaned = re.sub(r'[\\/:*?"<>|]', '_', text)
    trimmed = cleaned[:max_length].rstrip()
    return trimmed if trimmed else "untitled"
def distributed_setup():
    """Initialise torch.distributed from torchrun-style environment variables.

    Reads RANK / LOCAL_RANK / WORLD_SIZE (raises KeyError if missing),
    joins the NCCL process group, and pins this process to its local GPU
    so that plain "cuda" resolves to the right device.

    Returns:
        tuple[int, int, int]: (rank, local_rank, world_size)
    """
    env = os.environ
    rank, local_rank, world_size = (
        int(env['RANK']),
        int(env['LOCAL_RANK']),
        int(env['WORLD_SIZE']),
    )
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    return rank, local_rank, world_size
def main():
    """Distributed text-to-image generation with a FLUX pipeline.

    Each rank loads the pipeline onto its own GPU, takes its shard of the
    prompt list via DistributedSampler, and renders one 1024x1024 image per
    prompt into IMAGE_SAVE_FOLDER. Intended to be launched with torchrun.
    """
    rank, local_rank, world_size = distributed_setup()

    # NOTE(review): placeholder paths — substitute real checkpoint/output
    # locations before running.
    model_path = "CKPT_PATH"
    flux_path = "./ckpt/flux"

    # Load the (fine-tuned) transformer separately and splice it into the
    # base pipeline; transformer=None avoids loading the stock weights twice.
    transformer = FluxTransformer2DModel.from_pretrained(
        model_path, use_safetensors=True, torch_dtype=torch.float16
    ).to("cuda")
    pipe = FluxPipeline.from_pretrained(
        flux_path, transformer=None, torch_dtype=torch.float16
    ).to("cuda")
    pipe.transformer = transformer

    dataset = PromptDataset("scripts/evaluation/prompt_test.txt")
    # DistributedSampler partitions prompt indices across ranks;
    # shuffle=False keeps the prompt->rank assignment deterministic.
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=False,
    )

    output_dir = Path("IMAGE_SAVE_FOLDER")
    output_dir.mkdir(parents=True, exist_ok=True)

    for idx in sampler:
        prompt = dataset[idx]
        try:
            # Per-prompt seed, distinct across ranks, for reproducible output.
            generator = torch.Generator(device=f"cuda:{local_rank}")
            generator.manual_seed(42 + idx + rank * 1000)
            image = pipe(
                prompt,
                guidance_scale=3.5,
                height=1024,
                width=1024,
                num_inference_steps=50,
                max_sequence_length=512,
                generator=generator,
            ).images[0]
            # BUG FIX: the sanitized prompt was computed but never used in the
            # save path (a garbled placeholder name was saved instead, so every
            # image on a rank overwrote the same file). Name the file after
            # the sanitized prompt.
            filename = sanitize_filename(prompt)
            save_path = output_dir / f"{filename}.png"
            image.save(save_path)
            print(f"[Rank {rank}] Generated: {save_path.name}")
        except Exception as e:
            # Best-effort: log and continue so one failing prompt doesn't
            # kill the whole shard.
            print(f"[Rank {rank}] Error processing '{prompt[:20]}...': {str(e)}")

    dist.destroy_process_group()
# Script entry point: run only when executed directly (e.g. via torchrun).
if __name__ == "__main__":
    main()