"""Generate one FLUX image per prompt, sharding the prompt list across ranks.

Each process reads ``RANK``, ``LOCAL_RANK`` and ``WORLD_SIZE`` from the
environment (typically set by a launcher such as ``torchrun``), loads the
pipeline on its own GPU, and renders the prompts assigned to it by a
non-shuffling ``DistributedSampler``.
"""
import os
import re
from pathlib import Path

import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DistributedSampler

# Characters replaced by "_" when a prompt is turned into a file name.
_INVALID_FILENAME_CHARS = re.compile(r'[\\/:*?"<>|]')


class PromptDataset(Dataset):
    """Dataset of text prompts read from a file, one prompt per line."""

    def __init__(self, file_path):
        """Load the prompts from *file_path*, dropping blank lines.

        Surrounding whitespace is stripped from every prompt.
        """
        # Explicit encoding: the platform default may not be UTF-8, which
        # would corrupt or reject non-ASCII prompts.
        with open(file_path, 'r', encoding="utf-8") as f:
            self.prompts = [line.strip() for line in f if line.strip()]

    def __len__(self):
        """Return the number of prompts."""
        return len(self.prompts)

    def __getitem__(self, idx):
        """Return the prompt string stored at position *idx*."""
        return self.prompts[idx]


def sanitize_filename(text, max_length=200):
    """Turn *text* into a string usable as a file name.

    Each of the characters ``\\ / : * ? " < > |`` is replaced by ``_``, the
    result is cut to *max_length* characters and trailing whitespace is
    removed. ``"untitled"`` is returned when nothing is left.
    """
    sanitized = _INVALID_FILENAME_CHARS.sub('_', text)
    return sanitized[:max_length].rstrip() or "untitled"


def iter_unpadded_indices(sampler, rank: int, world_size: int, dataset_size: int):
    """Yield this rank's indices from *sampler*, stopping at padding entries.

    With ``drop_last=False`` a ``DistributedSampler`` repeats leading indices
    so that every rank receives the same number of samples. Without shuffling,
    the ``step``-th index handed to ``rank`` sits at position
    ``rank + step * world_size`` of the padded index list, so any position at
    or beyond *dataset_size* is a repeat that another rank already owns.
    """
    for step, idx in enumerate(sampler):
        if rank + step * world_size >= dataset_size:
            # Padding is appended at the end, so nothing real follows it.
            return
        yield idx


def distributed_setup():
    """Initialise the NCCL process group and select this process's GPU.

    Returns:
        The tuple ``(rank, local_rank, world_size)`` read from the ``RANK``,
        ``LOCAL_RANK`` and ``WORLD_SIZE`` environment variables.

    Raises:
        KeyError: if one of those environment variables is not set.
    """
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    return rank, local_rank, world_size


def main():
    """Render this rank's share of the prompts and save each image as a PNG.

    A failure on a single prompt is reported and skipped so that one bad
    prompt does not abort the whole run.
    """
    # Imported here, before the process group exists, so that the helpers in
    # this module stay importable without diffusers and a missing install
    # fails before any distributed state has been created.
    from diffusers import FluxPipeline, FluxTransformer2DModel

    rank, local_rank, world_size = distributed_setup()
    try:
        model_path = "CKPT_PATH"
        flux_path = "./ckpt/flux"

        # The fine-tuned transformer is loaded on its own; the base pipeline
        # is loaded without a transformer and the two are joined afterwards.
        transformer = FluxTransformer2DModel.from_pretrained(
            model_path, use_safetensors=True, torch_dtype=torch.float16
        ).to("cuda")
        pipe = FluxPipeline.from_pretrained(
            flux_path, transformer=None, torch_dtype=torch.float16
        ).to("cuda")
        pipe.transformer = transformer

        dataset = PromptDataset("scripts/evaluation/prompt_test.txt")
        sampler = DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=False,
        )

        output_dir = Path("IMAGE_SAVE_FOLDER")
        output_dir.mkdir(parents=True, exist_ok=True)

        device = f"cuda:{local_rank}"
        for idx in iter_unpadded_indices(sampler, rank, world_size, len(dataset)):
            prompt = dataset[idx]
            try:
                # NOTE: the seed depends on the rank, so the image produced
                # for a given prompt changes with the number of processes.
                generator = torch.Generator(device=device)
                generator.manual_seed(42 + idx + rank * 1000)

                image = pipe(
                    prompt,
                    guidance_scale=3.5,
                    height=1024,
                    width=1024,
                    num_inference_steps=50,
                    max_sequence_length=512,
                    generator=generator,
                ).images[0]

                # NOTE: prompts that sanitize to the same name overwrite each
                # other's output file.
                filename = sanitize_filename(prompt)
                save_path = output_dir / f"{filename}.png"
                image.save(save_path)
                print(f"[Rank {rank}] Generated: {save_path.name}")
            except Exception as e:
                # Deliberate best-effort: report the failure and move on.
                print(f"[Rank {rank}] Error processing '{prompt[:20]}...': {e}")
    finally:
        # Tear the process group down even when loading or generation fails.
        dist.destroy_process_group()


if __name__ == "__main__":
    main()