|
|
|
|
|
""" |
|
|
Trouter-Imagine-1 Comprehensive Inference Script |
|
|
Apache 2.0 License |
|
|
|
|
|
This script provides a complete interface for generating images using the |
|
|
OpenTrouter/Trouter-Imagine-1 model with extensive customization options, |
|
|
batch processing, and advanced features. |
|
|
|
|
|
Usage: |
|
|
python inference.py --prompt "your prompt here" --output output.png |
|
|
    python inference.py --batch prompts.txt --output-dir ./outputs/
|
|
python inference.py --interactive |
|
|
""" |
|
|
|
|
|
import torch |
|
|
from diffusers import ( |
|
|
StableDiffusionPipeline, |
|
|
DPMSolverMultistepScheduler, |
|
|
EulerAncestralDiscreteScheduler, |
|
|
DDIMScheduler, |
|
|
PNDMScheduler |
|
|
) |
|
|
from PIL import Image, ImageDraw, ImageFont |
|
|
import argparse |
|
|
import json |
|
|
import os |
|
|
import sys |
|
|
from pathlib import Path |
|
|
from typing import List, Dict, Optional, Tuple |
|
|
import time |
|
|
from datetime import datetime |
|
|
import random |
|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
import logging |
|
|
|
|
|
|
|
|
# Log both to a file in the working directory and to the console.
# The log file is written as UTF-8 so prompts containing non-ASCII text are
# recorded correctly regardless of the platform's default encoding.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('trouter_inference.log', encoding='utf-8'),
        logging.StreamHandler(),
    ],
)

logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class TrouterImageGenerator:
    """
    Comprehensive image generation class for the Trouter-Imagine-1 model.

    Features:
        - Multiple scheduler support
        - Batch processing
        - Memory optimization
        - Advanced parameter control
        - Image post-processing
        - Metadata embedding
    """

    def __init__(
        self,
        model_id: str = "OpenTrouter/Trouter-Imagine-1",
        device: str = "cuda",
        dtype: torch.dtype = torch.float16,
        enable_memory_optimization: bool = True
    ):
        """
        Initialize the image generator and load the pipeline.

        Args:
            model_id: HuggingFace model identifier
            device: Device to run inference on (cuda, cpu, mps)
            dtype: Data type for model weights. float16 is replaced by
                float32 when device is "cpu".
            enable_memory_optimization: On CUDA, enable attention/VAE slicing
                and xformers attention when available

        Raises:
            Exception: Any error while loading the model is logged and re-raised.
        """
        if device == "cpu" and dtype == torch.float16:
            # PyTorch's CPU backend has incomplete/slow float16 support, and
            # half precision there tends to fail inside the pipeline. Since
            # float16 is the CLI default, fall back instead of failing late.
            logger.warning("float16 is not well supported on CPU, using float32 instead")
            dtype = torch.float32

        self.model_id = model_id
        self.device = device
        self.dtype = dtype
        self.pipe = None
        # Running total of images produced by generate_image on this instance.
        self.generation_count = 0

        logger.info(f"Initializing Trouter-Imagine-1 on {device}")
        self._load_model(enable_memory_optimization)

    def _load_model(self, enable_optimization: bool):
        """
        Load the diffusion pipeline and move it to self.device.

        Args:
            enable_optimization: On CUDA, enable attention slicing, VAE slicing
                and (if installed) xformers memory-efficient attention.

        Raises:
            Exception: Any loading/placement error is logged and re-raised.
        """
        try:
            # The safety checker is deliberately disabled for this model.
            self.pipe = StableDiffusionPipeline.from_pretrained(
                self.model_id,
                torch_dtype=self.dtype,
                safety_checker=None,
                requires_safety_checker=False
            )

            if self.device == "mps":
                self.pipe = self.pipe.to("mps")
                # Attention slicing is always enabled on MPS, regardless of
                # enable_optimization.
                self.pipe.enable_attention_slicing()
            elif self.device == "cuda":
                self.pipe = self.pipe.to("cuda")

                if enable_optimization:
                    try:
                        self.pipe.enable_attention_slicing()
                        self.pipe.enable_vae_slicing()
                        logger.info("Memory optimizations enabled")
                    except Exception as e:
                        # Best effort: a failed optimization must not abort loading.
                        logger.warning(f"Some optimizations failed: {e}")

                    try:
                        self.pipe.enable_xformers_memory_efficient_attention()
                        logger.info("xformers memory efficient attention enabled")
                    except Exception:
                        # xformers is an optional dependency.
                        logger.info("xformers not available, using standard attention")
            else:
                self.pipe = self.pipe.to("cpu")
                logger.warning("Running on CPU - inference will be slow")

            logger.info("Model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def set_scheduler(self, scheduler_type: str):
        """
        Replace the pipeline's diffusion scheduler.

        The new scheduler is built from the current scheduler's config. An
        unknown name is logged and the current scheduler is kept.

        Args:
            scheduler_type: Type of scheduler (dpm, euler, ddim, pndm),
                case-insensitive
        """
        schedulers = {
            "dpm": DPMSolverMultistepScheduler,
            "euler": EulerAncestralDiscreteScheduler,
            "ddim": DDIMScheduler,
            "pndm": PNDMScheduler
        }

        scheduler_class = schedulers.get(scheduler_type.lower())
        if scheduler_class is None:
            logger.warning(f"Unknown scheduler {scheduler_type}, keeping current scheduler")
            return

        self.pipe.scheduler = scheduler_class.from_config(
            self.pipe.scheduler.config
        )
        logger.info(f"Scheduler set to {scheduler_type}")

    def generate_image(
        self,
        prompt: str,
        negative_prompt: str = "",
        width: int = 512,
        height: int = 512,
        num_inference_steps: int = 30,
        guidance_scale: float = 7.5,
        seed: Optional[int] = None,
        num_images: int = 1,
        callback_steps: int = 5
    ) -> Tuple[List[Image.Image], Dict]:
        """
        Generate images from a text prompt.

        Args:
            prompt: Text description of desired image
            negative_prompt: What to avoid in generation ("" means none)
            width: Image width; rounded down to a multiple of 8
            height: Image height; rounded down to a multiple of 8
            num_inference_steps: Number of denoising steps
            guidance_scale: Prompt adherence strength
            seed: Random seed for reproducibility. When None, a random seed
                is drawn and recorded in the metadata.
            num_images: Number of images to generate for this prompt
            callback_steps: Currently unused; kept for API compatibility

        Returns:
            Tuple of (generated images list, metadata dict). The metadata
            holds the effective width/height/seed that were actually used.

        Raises:
            ValueError: If width or height is smaller than 8 pixels.
            torch.cuda.OutOfMemoryError: If CUDA runs out of memory.
            Exception: Any other pipeline error is logged and re-raised.
        """
        if width < 8 or height < 8:
            # Rounding down below would produce a zero-sized image.
            raise ValueError(
                f"Width and height must be at least 8 pixels, got {width}x{height}"
            )

        if width % 8 != 0 or height % 8 != 0:
            logger.warning("Width and height must be multiples of 8, rounding...")
            width = (width // 8) * 8
            height = (height // 8) * 8

        if seed is None:
            seed = random.randint(0, 2**32 - 1)
            logger.info(f"Generated random seed: {seed}")
        else:
            logger.info(f"Using seed: {seed}")
        # Always seed explicitly so any result can be reproduced from its metadata.
        generator = torch.Generator(device=self.device).manual_seed(seed)

        metadata = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "width": width,
            "height": height,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "seed": seed,
            "model": self.model_id,
            "timestamp": datetime.now().isoformat()
        }

        logger.info(f"Generating {num_images} image(s)...")
        ellipsis = "..." if len(prompt) > 100 else ""
        logger.info(f"Prompt: {prompt[:100]}{ellipsis}")

        # Autocast only for half-precision CUDA pipelines. A float32 pipeline
        # (e.g. --dtype float32) must not be silently downcast by autocast.
        if self.device == "cuda" and self.dtype == torch.float16:
            precision_ctx = torch.autocast(self.device)
        else:
            precision_ctx = torch.no_grad()

        start_time = time.perf_counter()

        try:
            with precision_ctx:
                output = self.pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt if negative_prompt else None,
                    width=width,
                    height=height,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    num_images_per_prompt=num_images,
                    generator=generator
                )

            images = output.images
            generation_time = time.perf_counter() - start_time

            metadata["generation_time"] = generation_time
            metadata["images_generated"] = len(images)

            self.generation_count += len(images)

            logger.info(f"Generation complete in {generation_time:.2f}s")
            logger.info(f"Total images generated this session: {self.generation_count}")

            return images, metadata

        except torch.cuda.OutOfMemoryError:
            logger.error("CUDA out of memory! Try reducing resolution or batch size")
            raise
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise

    def generate_batch(
        self,
        prompts: List[str],
        output_dir: str = "./outputs",
        **generation_kwargs
    ) -> List[Tuple[Image.Image, Dict]]:
        """
        Generate and save images for each prompt in a list.

        A prompt whose generation or saving fails is logged and skipped. The
        remaining prompts are still processed.

        Args:
            prompts: List of text prompts
            output_dir: Directory to save generated images (created if missing)
            **generation_kwargs: Additional arguments passed to generate_image

        Returns:
            List of (image, metadata) tuples for every image that was saved.
            Images from the same prompt share one metadata dict.
        """
        results = []
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        logger.info(f"Starting batch generation of {len(prompts)} prompts")

        for i, prompt in enumerate(tqdm(prompts, desc="Generating images")):
            try:
                images, metadata = self.generate_image(prompt=prompt, **generation_kwargs)

                for j, image in enumerate(images):
                    # File names encode prompt index i and image index j.
                    filepath = output_path / f"batch_{i:04d}_{j:02d}.png"
                    self._save_image_with_metadata(image, filepath, metadata)
                    results.append((image, metadata))
                    logger.info(f"Saved: {filepath}")

            except Exception as e:
                logger.error(f"Failed to generate image {i}: {e}")
                continue

        logger.info(f"Batch generation complete. {len(results)} images saved to {output_dir}")
        return results

    def _save_image_with_metadata(
        self,
        image: Image.Image,
        filepath: Path,
        metadata: Dict
    ):
        """
        Save an image as PNG with metadata embedded as PNG text chunks.

        Args:
            image: Image to save
            filepath: Destination path (str or Path). Missing parent directories
                are created. The data is always written in PNG format,
                regardless of the file extension.
            metadata: Key/value pairs; each value is stored as str(value)
        """
        from PIL import PngImagePlugin

        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)

        png_info = PngImagePlugin.PngInfo()
        for key, value in metadata.items():
            png_info.add_text(key, str(value))

        image.save(filepath, "PNG", pnginfo=png_info)

    def generate_variations(
        self,
        prompt: str,
        num_variations: int = 4,
        output_dir: str = "./variations",
        **base_kwargs
    ) -> List[Tuple[Image.Image, Dict]]:
        """
        Generate variations of one prompt by using a different seed per run.

        Args:
            prompt: Text prompt
            num_variations: Number of variations to create
            output_dir: Output directory (created if missing)
            **base_kwargs: Base generation parameters for generate_image. If it
                contains ``seed``, variation i uses ``seed + i``, which gives a
                reproducible set. Otherwise each variation draws a random seed.

        Returns:
            List of (image, metadata) tuples
        """
        results = []
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # The seed is chosen per variation here. Forwarding a caller-supplied
        # seed as well would pass `seed` twice to generate_image (TypeError).
        base_seed = base_kwargs.pop("seed", None)

        logger.info(f"Generating {num_variations} variations of prompt")

        for i in range(num_variations):
            if base_seed is None:
                seed = random.randint(0, 2**32 - 1)
            else:
                seed = base_seed + i

            images, metadata = self.generate_image(
                prompt=prompt,
                seed=seed,
                **base_kwargs
            )

            for j, image in enumerate(images):
                filepath = output_path / f"variation_{i:02d}_{j:02d}_seed_{seed}.png"
                self._save_image_with_metadata(image, filepath, metadata)
                results.append((image, metadata))
                logger.info(f"Saved variation: {filepath}")

        return results

    def create_grid(
        self,
        images: List[Image.Image],
        rows: int = 2,
        cols: int = 2,
        output_path: str = "grid.png"
    ) -> Image.Image:
        """
        Arrange images row-major into a grid, then save it and return it.

        Surplus images beyond rows*cols are dropped. Missing cells stay black.

        Args:
            images: List of PIL Images. All images are assumed to share the
                size of the first image.
            rows: Number of rows
            cols: Number of columns
            output_path: Path to save grid

        Returns:
            Grid image

        Raises:
            ValueError: If images is empty or rows/cols is not positive.
        """
        if not images:
            raise ValueError("Cannot create a grid from an empty image list")
        if rows < 1 or cols < 1:
            raise ValueError(f"Grid dimensions must be positive, got {rows}x{cols}")

        if len(images) < rows * cols:
            logger.warning(f"Not enough images for {rows}x{cols} grid")

        # Cell size is taken from the first image.
        cell_w, cell_h = images[0].size
        grid = Image.new('RGB', (cols * cell_w, rows * cell_h))

        for idx, img in enumerate(images[:rows * cols]):
            row, col = divmod(idx, cols)
            grid.paste(img, (col * cell_w, row * cell_h))

        grid.save(output_path)
        logger.info(f"Grid saved to {output_path}")
        return grid

    def upscale_image(
        self,
        image: Image.Image,
        scale_factor: float = 2,
        method: str = "lanczos"
    ) -> Image.Image:
        """
        Upscale an image using simple interpolation (no model involved).

        Args:
            image: Input PIL Image
            scale_factor: Positive scaling factor. Integers are exact, and
                fractional factors such as 1.5 are also accepted.
            method: Interpolation method (lanczos, bicubic, bilinear, nearest).
                Unknown names fall back to lanczos, with a warning.

        Returns:
            Upscaled image

        Raises:
            ValueError: If scale_factor is not positive.
        """
        if scale_factor <= 0:
            raise ValueError(f"scale_factor must be positive, got {scale_factor}")

        methods = {
            "lanczos": Image.LANCZOS,
            "bicubic": Image.BICUBIC,
            "bilinear": Image.BILINEAR,
            "nearest": Image.NEAREST
        }

        key = method.lower()
        if key not in methods:
            logger.warning(f"Unknown upscale method {method}, using lanczos")
        resample = methods.get(key, Image.LANCZOS)

        # round() keeps integer factors exact. max(1, ...) avoids a zero-sized
        # result when a tiny factor is given.
        new_size = (
            max(1, round(image.width * scale_factor)),
            max(1, round(image.height * scale_factor)),
        )

        logger.info(f"Upscaling image from {image.size} to {new_size}")
        return image.resize(new_size, resample=resample)
|
|
|
|
|
|
|
|
def load_prompts_from_file(filepath: str) -> List[str]:
    """
    Load prompts from a text file, one prompt per line.

    Surrounding whitespace is stripped and blank lines are skipped. The file
    is decoded as UTF-8. A leading byte-order mark, as written by some
    Windows editors, is removed rather than being glued onto the first prompt.

    Args:
        filepath: Path to the prompt file.

    Returns:
        Non-empty, stripped prompts in file order.
    """
    # 'utf-8-sig' decodes plain UTF-8 identically and additionally drops a BOM.
    with open(filepath, 'r', encoding='utf-8-sig') as f:
        return [line.strip() for line in f if line.strip()]
|
|
|
|
|
|
|
|
def load_config_from_json(filepath: str) -> Dict:
    """
    Load a generation config from a JSON file.

    Args:
        filepath: Path to a UTF-8 JSON file (an optional BOM is tolerated)
            whose top-level value is an object.

    Returns:
        The parsed mapping of option names to values.

    Raises:
        ValueError: If the JSON is malformed (json.JSONDecodeError is a
            ValueError subclass) or its top-level value is not an object.
    """
    # JSON text is UTF-8 by specification, so don't depend on the locale's
    # default encoding. 'utf-8-sig' also accepts files saved with a BOM.
    with open(filepath, 'r', encoding='utf-8-sig') as f:
        config = json.load(f)

    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {filepath} must contain a JSON object, "
            f"got {type(config).__name__}"
        )
    return config
|
|
|
|
|
|
|
|
def _edit_interactive_settings(settings: Dict) -> None:
    """
    Print the interactive settings and let the user change them in place.

    Pressing Enter keeps a value. A new value is converted to the type of the
    current value (int or float); otherwise it is stored as a string. An
    unparsable number is reported and the old value is kept.

    Args:
        settings: The interactive settings dict; modified in place.
    """
    print("\nCurrent settings:")
    for key, value in settings.items():
        print(f" {key}: {value}")

    print("\nEnter new values (or press Enter to keep current):")

    # Only existing keys are reassigned, so iterating the dict is safe here.
    for key in settings:
        new_val = input(f" {key} [{settings[key]}]: ").strip()
        if not new_val:
            continue

        current = settings[key]
        try:
            if isinstance(current, int):
                settings[key] = int(new_val)
            elif isinstance(current, float):
                settings[key] = float(new_val)
            else:
                settings[key] = new_val
        except ValueError:
            print(f"Invalid value for {key}, keeping current")


def interactive_mode(generator: TrouterImageGenerator):
    """
    Run a prompt loop that generates and saves images interactively.

    Each prompt is generated with the current settings. Every image is saved
    as a timestamped PNG, with embedded metadata, in settings["output_dir"].
    Commands: 'quit'/'exit'/'q' leave the loop, and 'config' edits the settings.
    End-of-input (Ctrl-D) or Ctrl-C at the prompt also ends the session.
    Generation errors are printed and the loop continues.

    Args:
        generator: The image generator used for every prompt.
    """
    print("\n" + "="*60)
    print("Trouter-Imagine-1 Interactive Mode")
    print("="*60)
    print("Type 'quit' or 'exit' to stop")
    print("Type 'config' to change generation settings")
    print("="*60 + "\n")

    # Default generation settings; editable via the 'config' command.
    settings = {
        "width": 512,
        "height": 512,
        "steps": 30,
        "guidance": 7.5,
        "negative_prompt": "blurry, low quality, distorted",
        "num_images": 1,
        "output_dir": "./interactive_outputs"
    }

    Path(settings["output_dir"]).mkdir(parents=True, exist_ok=True)

    while True:
        try:
            prompt = input("\nEnter your prompt (or command): ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt ends the session cleanly.
            print("\nExiting interactive mode...")
            break

        command = prompt.lower()

        if command in ('quit', 'exit', 'q'):
            print("Exiting interactive mode...")
            break

        if command == 'config':
            try:
                _edit_interactive_settings(settings)
            except (EOFError, KeyboardInterrupt):
                print("\nConfig editing cancelled")
            continue

        if not prompt:
            print("Please enter a valid prompt")
            continue

        try:
            print(f"\nGenerating with prompt: {prompt}")
            images, metadata = generator.generate_image(
                prompt=prompt,
                negative_prompt=settings["negative_prompt"],
                width=settings["width"],
                height=settings["height"],
                num_inference_steps=settings["steps"],
                guidance_scale=settings["guidance"],
                num_images=settings["num_images"]
            )

            # output_dir may have been changed via 'config', so make sure the
            # current one exists before saving into it.
            output_dir = Path(settings["output_dir"])
            output_dir.mkdir(parents=True, exist_ok=True)

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            for i, image in enumerate(images):
                filepath = output_dir / f"{timestamp}_{i:02d}.png"
                generator._save_image_with_metadata(image, filepath, metadata)
                print(f"Saved: {filepath}")

        except Exception as e:
            print(f"Error: {e}")
|
|
|
|
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    """
    Build the command-line parser for this script.

    Returns:
        An ArgumentParser with model, generation, batch, output and mode options.
    """
    parser = argparse.ArgumentParser(
        description="Trouter-Imagine-1 Image Generation Script",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Generate single image
  python inference.py --prompt "a beautiful sunset" --output sunset.png

  # Generate with custom parameters
  python inference.py --prompt "cyberpunk city" --width 768 --height 768 --steps 50

  # Batch generation from file
  python inference.py --batch prompts.txt --output-dir ./batch_outputs/

  # Generate variations
  python inference.py --prompt "mountain landscape" --variations 8

  # Interactive mode
  python inference.py --interactive

  # Use different scheduler
  python inference.py --prompt "portrait" --scheduler dpm
"""
    )

    # Model settings
    parser.add_argument("--model", type=str, default="OpenTrouter/Trouter-Imagine-1",
                        help="HuggingFace model ID")
    parser.add_argument("--device", type=str, default="cuda",
                        choices=["cuda", "cpu", "mps"],
                        help="Device to run inference on")
    parser.add_argument("--dtype", type=str, default="float16",
                        choices=["float16", "float32"],
                        help="Model precision")
    parser.add_argument("--no-optimization", action="store_true",
                        help="Disable memory optimizations")

    # Generation parameters
    parser.add_argument("--prompt", type=str,
                        help="Text prompt for generation")
    parser.add_argument("--negative-prompt", type=str, default="",
                        help="Negative prompt")
    parser.add_argument("--width", type=int, default=512,
                        help="Image width")
    parser.add_argument("--height", type=int, default=512,
                        help="Image height")
    parser.add_argument("--steps", type=int, default=30,
                        help="Number of inference steps")
    parser.add_argument("--guidance", type=float, default=7.5,
                        help="Guidance scale")
    parser.add_argument("--seed", type=int,
                        help="Random seed")
    parser.add_argument("--num-images", type=int, default=1,
                        help="Number of images to generate")
    parser.add_argument("--scheduler", type=str,
                        choices=["dpm", "euler", "ddim", "pndm"],
                        help="Diffusion scheduler to use")

    # Batch / variation options
    parser.add_argument("--batch", type=str,
                        help="File containing prompts (one per line)")
    parser.add_argument("--variations", type=int,
                        help="Generate N variations of the prompt")
    parser.add_argument("--grid", action="store_true",
                        help="Create grid from generated images")
    parser.add_argument("--grid-rows", type=int, default=2,
                        help="Grid rows")
    parser.add_argument("--grid-cols", type=int, default=2,
                        help="Grid columns")

    # Output options
    parser.add_argument("--output", type=str, default="output.png",
                        help="Output filepath")
    # "--output_dir" is also accepted because it is a common spelling of this
    # option. argparse derives dest="output_dir" from the first long option.
    parser.add_argument("--output-dir", "--output_dir", type=str, default="./outputs",
                        help="Output directory for batch generation")

    # Modes
    parser.add_argument("--interactive", action="store_true",
                        help="Enter interactive mode")
    parser.add_argument("--config", type=str,
                        help="Load config from JSON file")

    return parser


def _apply_config_overrides(args: argparse.Namespace, config: Dict) -> List[str]:
    """
    Copy values from a JSON config onto the parsed CLI arguments.

    Keys may use the attribute spelling ("negative_prompt") or the CLI
    spelling ("negative-prompt" / "--negative-prompt"). Config values take
    precedence over values given on the command line.

    Args:
        args: Parsed arguments; modified in place.
        config: Mapping of option names to values.

    Returns:
        The config keys that matched no known option and were ignored.
    """
    ignored = []
    for key, value in config.items():
        attr = key.lstrip("-").replace("-", "_")
        if hasattr(args, attr):
            setattr(args, attr, value)
        else:
            ignored.append(key)
    return ignored


def _save_grid(generator, results, args, filename):
    """Combine generated images into one grid file under args.output_dir, if any exist."""
    images = [img for img, _ in results]
    if not images:
        logger.warning("No images were generated; skipping grid creation")
        return
    generator.create_grid(
        images,
        rows=args.grid_rows,
        cols=args.grid_cols,
        output_path=os.path.join(args.output_dir, filename)
    )


def main():
    """
    Main entry point with CLI argument parsing.

    Exactly one mode runs, checked in this order: --interactive, --batch,
    --variations (requires --prompt), then --prompt. When no mode is given,
    help is printed and the process exits with status 1, before any model
    is loaded.
    """
    parser = _build_parser()
    args = parser.parse_args()

    if args.config:
        config = load_config_from_json(args.config)
        for key in _apply_config_overrides(args, config):
            logger.warning(f"Ignoring unknown config key: {key}")

    # Validate the requested mode before loading the (large) model.
    if not (args.interactive or args.batch or args.prompt):
        parser.print_help()
        print(
            "\nError: Please specify --prompt (optionally with --variations), "
            "--batch, or --interactive",
            file=sys.stderr
        )
        sys.exit(1)

    dtype = torch.float16 if args.dtype == "float16" else torch.float32

    logger.info("Initializing Trouter-Imagine-1 generator...")
    generator = TrouterImageGenerator(
        model_id=args.model,
        device=args.device,
        dtype=dtype,
        enable_memory_optimization=not args.no_optimization
    )

    if args.scheduler:
        generator.set_scheduler(args.scheduler)

    if args.interactive:
        interactive_mode(generator)
        return

    gen_kwargs = {
        "width": args.width,
        "height": args.height,
        "num_inference_steps": args.steps,
        "guidance_scale": args.guidance,
        "negative_prompt": args.negative_prompt,
        "num_images": args.num_images
    }
    if args.seed is not None:
        gen_kwargs["seed"] = args.seed

    if args.batch:
        prompts = load_prompts_from_file(args.batch)
        results = generator.generate_batch(
            prompts=prompts,
            output_dir=args.output_dir,
            **gen_kwargs
        )
        if args.grid:
            _save_grid(generator, results, args, "grid.png")
        return

    if args.variations and args.prompt:
        results = generator.generate_variations(
            prompt=args.prompt,
            num_variations=args.variations,
            output_dir=args.output_dir,
            **gen_kwargs
        )
        if args.grid:
            _save_grid(generator, results, args, "variations_grid.png")
        return

    # Single-prompt mode (args.prompt is guaranteed by the check above).
    images, metadata = generator.generate_image(
        prompt=args.prompt,
        **gen_kwargs
    )

    for i, image in enumerate(images):
        if len(images) > 1:
            # Multiple images: suffix the index before the extension.
            base, ext = os.path.splitext(args.output)
            filepath = Path(f"{base}_{i:02d}{ext}")
        else:
            filepath = Path(args.output)

        filepath.parent.mkdir(parents=True, exist_ok=True)
        generator._save_image_with_metadata(image, filepath, metadata)
        logger.info(f"Image saved to: {filepath}")


if __name__ == "__main__":
    main()