# Trouter-Imagine-1 / pipeline.py
# Trouter-Library's picture
# Create pipeline.py
# 3fd5fbc verified
#!/usr/bin/env python3
"""
Trouter-Imagine-1 Complete Pipeline
Apache 2.0 License
This file provides a complete, ready-to-use pipeline for text-to-image generation.
It includes all necessary components and can be used immediately for generating images.
This is the MAIN FILE for using the model - simple and powerful.
"""
import torch
from diffusers import (
StableDiffusionPipeline,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
DDIMScheduler
)
from PIL import Image
import os
from typing import List, Optional, Union, Dict
import warnings
import logging
from pathlib import Path
import json
from datetime import datetime
# Module-wide logging: timestamped INFO-level messages via the root handler.
# Configured at import time so every component below logs consistently.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class TrouterImagePipeline:
    """
    Complete ready-to-use pipeline for Trouter-Imagine-1.

    This is the main class you should use for image generation.
    It's simple, powerful, and handles everything automatically.

    Example:
        >>> pipeline = TrouterImagePipeline()
        >>> image = pipeline("a beautiful sunset")
        >>> image.save("sunset.png")
    """

    # Default base model (you can change this to your custom model once trained)
    DEFAULT_MODEL = "runwayml/stable-diffusion-v1-5"
    # Can also use these alternatives:
    #   "stabilityai/stable-diffusion-2-1"
    #   "stabilityai/stable-diffusion-xl-base-1.0"

    def __init__(
        self,
        model_id: Optional[str] = None,
        device: Optional[str] = None,
        torch_dtype: torch.dtype = torch.float16,
        use_safetensors: bool = True,
        enable_optimizations: bool = True
    ):
        """
        Initialize the Trouter-Imagine-1 pipeline.

        Args:
            model_id: Model to use (defaults to Stable Diffusion 1.5)
            device: Device to use ("cuda"/"mps"/"cpu"; auto-detected if None)
            torch_dtype: Model precision (float16 for speed, float32 for quality)
            use_safetensors: Use safetensors format (recommended)
            enable_optimizations: Enable memory optimizations

        Raises:
            Exception: re-raised from diffusers if the model fails to load.
        """
        # Auto-detect the best available device when none is given.
        if device is None:
            if torch.cuda.is_available():
                device = "cuda"
                logger.info(f"CUDA available: {torch.cuda.get_device_name(0)}")
            elif torch.backends.mps.is_available():
                device = "mps"
                logger.info("Using Apple Silicon (MPS)")
            else:
                device = "cpu"
                logger.warning("No GPU detected, using CPU (will be slow)")

        self.device = device
        self.dtype = torch_dtype
        self.model_id = model_id or self.DEFAULT_MODEL

        logger.info("Initializing Trouter-Imagine-1 Pipeline")
        logger.info(f"Model: {self.model_id}")
        logger.info(f"Device: {self.device}")
        logger.info(f"Precision: {self.dtype}")

        # Load the diffusers pipeline onto the chosen device.
        self._load_pipeline(use_safetensors)

        # Apply memory/speed optimizations unless explicitly disabled.
        if enable_optimizations:
            self._optimize()

        # Negative prompt used whenever the caller does not supply one.
        self.default_negative = "blurry, low quality, distorted, deformed, ugly, bad anatomy, watermark, signature, text"

        logger.info("✓ Pipeline ready!")

    def _load_pipeline(self, use_safetensors: bool):
        """Load the diffusion pipeline and install the default scheduler."""
        try:
            self.pipe = StableDiffusionPipeline.from_pretrained(
                self.model_id,
                torch_dtype=self.dtype,
                use_safetensors=use_safetensors,
                safety_checker=None,  # Disable for flexibility
                requires_safety_checker=False
            )
            # Move to device
            self.pipe = self.pipe.to(self.device)
            # Set better scheduler by default (DPM-Solver++ converges in fewer steps)
            self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                self.pipe.scheduler.config
            )
            logger.info("✓ Model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def _optimize(self):
        """Apply memory and speed optimizations (best-effort, never fatal)."""
        logger.info("Applying optimizations...")
        try:
            # Memory optimizations: slice attention/VAE work to lower peak VRAM.
            self.pipe.enable_attention_slicing()
            self.pipe.enable_vae_slicing()
            logger.info(" ✓ Memory optimizations enabled")
        except Exception as e:
            logger.warning(f" ⚠ Memory optimization failed: {e}")
        # Try xformers for even better performance; absence is not an error.
        try:
            self.pipe.enable_xformers_memory_efficient_attention()
            logger.info(" ✓ xformers enabled (faster generation)")
        except Exception:
            logger.info(" ℹ xformers not available (this is fine)")
        # Model CPU offload for very limited VRAM
        # Uncomment if you have < 6GB VRAM:
        # self.pipe.enable_model_cpu_offload()

    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        width: int = 512,
        height: int = 512,
        num_inference_steps: int = 30,
        guidance_scale: float = 7.5,
        num_images: int = 1,
        seed: Optional[int] = None,
        return_dict: bool = False
    ) -> Union[Image.Image, List[Image.Image], Dict]:
        """
        Generate images from text prompt.

        Args:
            prompt: Text description or list of descriptions
            negative_prompt: What to avoid (uses default if None)
            width: Image width (rounded down to a multiple of 8, minimum 8)
            height: Image height (rounded down to a multiple of 8, minimum 8)
            num_inference_steps: Quality (20=fast, 30=balanced, 50=quality)
            guidance_scale: Prompt adherence (7.5 is good default)
            num_images: Number of images to generate per prompt
            seed: Random seed for reproducibility
            return_dict: Return dictionary with metadata

        Returns:
            A single image when exactly one was generated, otherwise a list;
            or a dictionary with images and metadata when return_dict is True.

        Raises:
            torch.cuda.OutOfMemoryError: re-raised after logging hints.
            Exception: any other generation failure, re-raised after logging.
        """
        # Use the pipeline's default negative prompt if none provided.
        if negative_prompt is None:
            negative_prompt = self.default_negative

        # A seeded generator makes results reproducible on the same device.
        generator = None
        if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(seed)

        # Dimensions must be multiples of 8; round down but never below 8
        # so a degenerate request (e.g. width=4) cannot yield a 0-size image.
        if width % 8 != 0:
            width = max((width // 8) * 8, 8)
            logger.warning(f"Width adjusted to {width} (must be multiple of 8)")
        if height % 8 != 0:
            height = max((height // 8) * 8, 8)
            logger.warning(f"Height adjusted to {height} (must be multiple of 8)")

        # Generate
        logger.info(f"Generating: {prompt[:100]}...")
        try:
            # Mixed precision on CUDA; plain no-grad elsewhere (autocast here
            # is used for CUDA only; diffusers disables grad internally).
            ctx = torch.autocast(self.device) if self.device == "cuda" else torch.no_grad()
            with ctx:
                output = self.pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    width=width,
                    height=height,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    num_images_per_prompt=num_images,
                    generator=generator
                )
            images = output.images
            logger.info(f"✓ Generated {len(images)} image(s)")
            if return_dict:
                return {
                    'images': images,
                    'prompt': prompt,
                    'negative_prompt': negative_prompt,
                    'width': width,
                    'height': height,
                    'steps': num_inference_steps,
                    'guidance': guidance_scale,
                    'seed': seed
                }
            return images[0] if len(images) == 1 else images
        except torch.cuda.OutOfMemoryError:
            logger.error("GPU out of memory! Try:")
            logger.error(" 1. Reduce resolution (e.g., 512x512 instead of 1024x1024)")
            logger.error(" 2. Reduce num_images")
            logger.error(" 3. Close other applications")
            raise
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise

    def generate_batch(
        self,
        prompts: List[str],
        output_dir: str = "./outputs",
        **kwargs
    ) -> List[Image.Image]:
        """
        Generate multiple images from different prompts and save each to disk.

        Args:
            prompts: List of text prompts
            output_dir: Directory to save images (created if missing)
            **kwargs: Additional generation parameters forwarded to __call__

        Returns:
            List of generated images

        NOTE(review): assumes each call yields a single image — pass
        num_images=1 (the default); with num_images>1 the per-item .save()
        would fail on a list.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        images = []
        logger.info(f"Generating batch of {len(prompts)} images...")
        for i, prompt in enumerate(prompts):
            logger.info(f" [{i+1}/{len(prompts)}] {prompt[:50]}...")
            image = self(prompt, **kwargs)
            images.append(image)
            # Save with a zero-padded index so files sort in generation order.
            filename = output_path / f"image_{i:04d}.png"
            image.save(filename)
            # Bug fix: previously logged the literal "(unknown)" instead of
            # the actual destination path.
            logger.info(f" ✓ Saved to {filename}")
        logger.info(f"✓ Batch complete! {len(images)} images in {output_dir}")
        return images

    def generate_variations(
        self,
        prompt: str,
        num_variations: int = 4,
        **kwargs
    ) -> List[Image.Image]:
        """
        Generate variations of the same prompt (different random seeds).

        Args:
            prompt: Text prompt
            num_variations: Number of variations
            **kwargs: Additional generation parameters forwarded to __call__

        Returns:
            List of image variations
        """
        logger.info(f"Generating {num_variations} variations...")
        images = []
        for i in range(num_variations):
            # Draw a fresh 32-bit seed per variation for distinct outputs.
            seed = torch.randint(0, 2**32, (1,)).item()
            image = self(prompt, seed=seed, **kwargs)
            images.append(image)
            logger.info(f" ✓ Variation {i+1}/{num_variations}")
        return images

    def set_scheduler(self, scheduler_name: str):
        """
        Change the diffusion scheduler.

        Args:
            scheduler_name: 'dpm' (fast), 'euler' (creative), 'ddim' (stable).
                Unknown names are logged and ignored (no change).
        """
        schedulers = {
            'dpm': DPMSolverMultistepScheduler,
            'euler': EulerAncestralDiscreteScheduler,
            'ddim': DDIMScheduler,
        }
        if scheduler_name.lower() not in schedulers:
            logger.warning(f"Unknown scheduler: {scheduler_name}")
            return
        scheduler_class = schedulers[scheduler_name.lower()]
        # Rebuild the scheduler from the current config so settings carry over.
        self.pipe.scheduler = scheduler_class.from_config(
            self.pipe.scheduler.config
        )
        logger.info(f"✓ Scheduler changed to {scheduler_name}")

    def save_pipeline(self, save_path: str):
        """Save the complete pipeline (weights + configs) to save_path."""
        self.pipe.save_pretrained(save_path)
        logger.info(f"✓ Pipeline saved to {save_path}")

    def get_config(self) -> Dict:
        """Get current pipeline configuration as a plain dictionary."""
        return {
            'model_id': self.model_id,
            'device': str(self.device),
            'dtype': str(self.dtype),
            'scheduler': self.pipe.scheduler.__class__.__name__,
            'default_negative_prompt': self.default_negative
        }
# ============================================================================
# CONVENIENCE FUNCTIONS
# ============================================================================
def quick_generate(
    prompt: str,
    output_path: str = "output.png",
    quality: str = "balanced",
    **kwargs
) -> Image.Image:
    """
    Quick one-line image generation.

    Args:
        prompt: What to generate
        output_path: Where to save
        quality: 'draft' (fast), 'balanced', 'high', 'ultra' — unknown
            values fall back to 'balanced'
        **kwargs: Additional parameters (override the quality preset)

    Returns:
        Generated image

    Example:
        >>> quick_generate("a cat in a hat", "cat.png")
    """
    presets = {
        'draft': {'num_inference_steps': 15, 'width': 512, 'height': 512},
        'balanced': {'num_inference_steps': 30, 'width': 512, 'height': 512},
        'high': {'num_inference_steps': 40, 'width': 768, 'height': 768},
        'ultra': {'num_inference_steps': 50, 'width': 1024, 'height': 1024}
    }
    # Merge the chosen preset with caller overrides (caller wins).
    chosen = presets.get(quality, presets['balanced'])
    params = {**chosen, **kwargs}

    generator = TrouterImagePipeline()
    result = generator(prompt, **params)
    result.save(output_path)
    logger.info(f"✓ Image saved to {output_path}")
    return result
def batch_from_file(
    prompts_file: str,
    output_dir: str = "./outputs",
    **kwargs
) -> List[Image.Image]:
    """
    Generate images from prompts in a text file.

    Args:
        prompts_file: Text file with one prompt per line (blank lines skipped)
        output_dir: Where to save images
        **kwargs: Generation parameters forwarded to generate_batch

    Returns:
        List of generated images
    """
    # Explicit encoding: the default is locale-dependent and can garble
    # non-ASCII prompts on some platforms.
    with open(prompts_file, 'r', encoding='utf-8') as f:
        prompts = [line.strip() for line in f if line.strip()]
    pipeline = TrouterImagePipeline()
    return pipeline.generate_batch(prompts, output_dir, **kwargs)
# ============================================================================
# PRESETS AND STYLES
# ============================================================================
# Prompt/parameter presets consumed by generate_with_style(). Each entry has:
#   prompt_suffix   — appended verbatim to the user's base prompt
#   negative_prompt — style-specific negative prompt for the pipeline
#   guidance_scale  — prompt-adherence strength passed to the pipeline
STYLE_PRESETS = {
    'photorealistic': {
        'prompt_suffix': ', professional photography, photorealistic, 4k, highly detailed',
        'negative_prompt': 'cartoon, anime, painting, illustration, low quality, blurry',
        'guidance_scale': 8.5
    },
    'artistic': {
        'prompt_suffix': ', digital art, concept art, detailed illustration',
        'negative_prompt': 'photograph, realistic, blurry, low quality',
        'guidance_scale': 7.0
    },
    'anime': {
        'prompt_suffix': ', anime style, manga, cel shaded, vibrant colors',
        'negative_prompt': 'realistic, 3d, photograph, blurry, low quality',
        'guidance_scale': 7.5
    },
    'oil_painting': {
        'prompt_suffix': ', oil painting, painterly, artistic, brushstrokes',
        'negative_prompt': 'photograph, digital, 3d render, blurry',
        'guidance_scale': 7.5
    },
    'cinematic': {
        'prompt_suffix': ', cinematic lighting, film still, dramatic, movie scene',
        'negative_prompt': 'amateur, low quality, poor lighting, blurry',
        'guidance_scale': 8.0
    }
}
def generate_with_style(
    prompt: str,
    style: str = 'photorealistic',
    output_path: str = "styled_output.png",
    **kwargs
) -> Image.Image:
    """
    Generate image with predefined style preset.

    Args:
        prompt: Base prompt (the preset's suffix is appended to it)
        style: Style preset name (unknown names fall back to 'photorealistic')
        output_path: Where to save
        **kwargs: Additional parameters. An explicitly supplied
            negative_prompt is respected; guidance_scale always comes
            from the preset.

    Returns:
        Generated image
    """
    if style not in STYLE_PRESETS:
        logger.warning(f"Unknown style: {style}, using photorealistic")
        style = 'photorealistic'
    preset = STYLE_PRESETS[style]

    # Apply style: append the preset suffix to the caller's prompt.
    full_prompt = prompt + preset['prompt_suffix']
    # Bug fix: a caller-supplied negative_prompt used to be silently
    # overwritten by the preset; setdefault lets the explicit value win.
    kwargs.setdefault('negative_prompt', preset['negative_prompt'])
    # guidance_scale intentionally comes from the preset (existing behavior
    # the CLI relies on when --style is given).
    kwargs['guidance_scale'] = preset['guidance_scale']

    pipeline = TrouterImagePipeline()
    image = pipeline(full_prompt, **kwargs)
    image.save(output_path)
    logger.info(f"✓ {style.title()} style image saved to {output_path}")
    return image
# ============================================================================
# MAIN - COMMAND LINE INTERFACE
# ============================================================================
def main():
    """Simple command line interface: parse args, dispatch to a generator."""
    import argparse

    parser = argparse.ArgumentParser(description="Trouter-Imagine-1 Image Generator")
    parser.add_argument("prompt", type=str, help="Text prompt for generation")
    parser.add_argument("--output", "-o", type=str, default="output.png",
                        help="Output file path")
    parser.add_argument("--quality", "-q", type=str, default="balanced",
                        choices=['draft', 'balanced', 'high', 'ultra'],
                        help="Quality preset")
    parser.add_argument("--style", "-s", type=str,
                        choices=list(STYLE_PRESETS.keys()),
                        help="Style preset")
    parser.add_argument("--seed", type=int, help="Random seed")
    parser.add_argument("--width", type=int, default=512, help="Image width")
    parser.add_argument("--height", type=int, default=512, help="Image height")
    parser.add_argument("--steps", type=int, default=30, help="Inference steps")
    parser.add_argument("--guidance", type=float, default=7.5, help="Guidance scale")
    parser.add_argument("--negative", type=str, help="Negative prompt")
    opts = parser.parse_args()

    # Collect generation parameters shared by both dispatch paths.
    gen_params = dict(
        width=opts.width,
        height=opts.height,
        num_inference_steps=opts.steps,
        guidance_scale=opts.guidance,
        seed=opts.seed,
    )
    if opts.negative:
        gen_params['negative_prompt'] = opts.negative

    # A style preset takes priority over the plain quality preset.
    if opts.style:
        generate_with_style(opts.prompt, opts.style, opts.output, **gen_params)
    else:
        quick_generate(opts.prompt, opts.output, opts.quality, **gen_params)
if __name__ == "__main__":
print("="*70)
print("TROUTER-IMAGINE-1 IMAGE GENERATION PIPELINE")
print("Apache 2.0 License")
print("="*70)
print()
print("Quick Start Examples:")
print()
print(" # Python:")
print(" from pipeline import TrouterImagePipeline")
print(" pipeline = TrouterImagePipeline()")
print(" image = pipeline('a beautiful sunset over mountains')")
print(" image.save('sunset.png')")
print()
print(" # Command line:")
print(" python pipeline.py 'a cat in a hat' --output cat.png")
print(" python pipeline.py 'portrait' --style photorealistic --quality high")
print()
print("="*70)
print()
# Run CLI if arguments provided
import sys
if len(sys.argv) > 1:
main()