| |
| """ |
| Trouter-Imagine-1 Complete Pipeline |
| Apache 2.0 License |
| |
| This file provides a complete, ready-to-use pipeline for text-to-image generation. |
| It includes all necessary components and can be used immediately for generating images. |
| |
| This is the MAIN FILE for using the model - simple and powerful. |
| """ |
|
|
| import torch |
| from diffusers import ( |
| StableDiffusionPipeline, |
| DPMSolverMultistepScheduler, |
| EulerAncestralDiscreteScheduler, |
| DDIMScheduler |
| ) |
| from PIL import Image |
| import os |
| from typing import List, Optional, Union, Dict |
| import warnings |
| import logging |
| from pathlib import Path |
| import json |
| from datetime import datetime |
|
|
# Configure root logging once at import time so every pipeline message
# carries a timestamp and severity level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Module-level logger shared by all functions/classes in this file.
logger = logging.getLogger(__name__)
|
|
|
|
class TrouterImagePipeline:
    """
    Complete ready-to-use pipeline for Trouter-Imagine-1

    This is the main class you should use for image generation.
    It's simple, powerful, and handles everything automatically.

    Example:
        >>> pipeline = TrouterImagePipeline()
        >>> image = pipeline("a beautiful sunset")
        >>> image.save("sunset.png")
    """

    # Underlying diffusion checkpoint used when no model_id is supplied.
    DEFAULT_MODEL = "runwayml/stable-diffusion-v1-5"

    def __init__(
        self,
        model_id: Optional[str] = None,
        device: Optional[str] = None,
        torch_dtype: torch.dtype = torch.float16,
        use_safetensors: bool = True,
        enable_optimizations: bool = True
    ):
        """
        Initialize the Trouter-Imagine-1 pipeline

        Args:
            model_id: Model to use (defaults to Stable Diffusion 1.5)
            device: Device to use (auto-detected if None)
            torch_dtype: Model precision (float16 for speed, float32 for quality)
            use_safetensors: Use safetensors format (recommended)
            enable_optimizations: Enable memory optimizations
        """
        # Auto-detect the best available device when none was requested.
        if device is None:
            if torch.cuda.is_available():
                device = "cuda"
                logger.info(f"CUDA available: {torch.cuda.get_device_name(0)}")
            elif torch.backends.mps.is_available():
                device = "mps"
                logger.info("Using Apple Silicon (MPS)")
            else:
                device = "cpu"
                logger.warning("No GPU detected, using CPU (will be slow)")

        # Robustness fix: float16 is poorly supported on CPU and typically
        # fails at inference time, so fall back to float32 there.
        if device == "cpu" and torch_dtype == torch.float16:
            logger.warning("float16 is not supported on CPU; falling back to float32")
            torch_dtype = torch.float32

        self.device = device
        self.dtype = torch_dtype
        self.model_id = model_id or self.DEFAULT_MODEL

        logger.info(f"Initializing Trouter-Imagine-1 Pipeline")
        logger.info(f"Model: {self.model_id}")
        logger.info(f"Device: {self.device}")
        logger.info(f"Precision: {self.dtype}")

        self._load_pipeline(use_safetensors)

        if enable_optimizations:
            self._optimize()

        # Applied whenever the caller does not supply a negative prompt.
        self.default_negative = "blurry, low quality, distorted, deformed, ugly, bad anatomy, watermark, signature, text"

        logger.info("✓ Pipeline ready!")

    def _load_pipeline(self, use_safetensors: bool):
        """Load the diffusion pipeline, move it to the target device and
        install the default (DPM-Solver++) scheduler.

        Raises:
            Exception: re-raised from diffusers if the model cannot be loaded.
        """
        try:
            # NOTE: the safety checker is intentionally disabled here.
            self.pipe = StableDiffusionPipeline.from_pretrained(
                self.model_id,
                torch_dtype=self.dtype,
                use_safetensors=use_safetensors,
                safety_checker=None,
                requires_safety_checker=False
            )

            self.pipe = self.pipe.to(self.device)

            # DPM-Solver multistep is a good speed/quality default.
            self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                self.pipe.scheduler.config
            )

            logger.info("✓ Model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def _optimize(self):
        """Apply best-effort memory and speed optimizations.

        Each optimization is optional: failures are logged and ignored so
        the pipeline still works on installs without xformers etc.
        """
        logger.info("Applying optimizations...")

        try:
            # Slicing trades a little speed for a large VRAM reduction.
            self.pipe.enable_attention_slicing()
            self.pipe.enable_vae_slicing()
            logger.info("  ✓ Memory optimizations enabled")
        except Exception as e:
            logger.warning(f"  ⚠ Memory optimization failed: {e}")

        try:
            self.pipe.enable_xformers_memory_efficient_attention()
            logger.info("  ✓ xformers enabled (faster generation)")
        except Exception:
            logger.info("  ℹ xformers not available (this is fine)")

    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        width: int = 512,
        height: int = 512,
        num_inference_steps: int = 30,
        guidance_scale: float = 7.5,
        num_images: int = 1,
        seed: Optional[int] = None,
        return_dict: bool = False
    ) -> Union[Image.Image, List[Image.Image], Dict]:
        """
        Generate images from text prompt

        Args:
            prompt: Text description or list of descriptions
            negative_prompt: What to avoid (uses default if None)
            width: Image width (must be multiple of 8)
            height: Image height (must be multiple of 8)
            num_inference_steps: Quality (20=fast, 30=balanced, 50=quality)
            guidance_scale: Prompt adherence (7.5 is good default)
            num_images: Number of images to generate
            seed: Random seed for reproducibility
            return_dict: Return dictionary with metadata

        Returns:
            Generated image(s) or dictionary with images and metadata

        Raises:
            torch.cuda.OutOfMemoryError: when the GPU runs out of memory.
        """
        if negative_prompt is None:
            negative_prompt = self.default_negative

        # A seeded generator makes results reproducible.
        generator = None
        if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(seed)

        # The UNet/VAE require dimensions divisible by 8; snap down rather
        # than failing, but never below 8 (fix: previous code could yield 0).
        if width % 8 != 0:
            width = max(8, (width // 8) * 8)
            logger.warning(f"Width adjusted to {width} (must be multiple of 8)")
        if height % 8 != 0:
            height = max(8, (height // 8) * 8)
            logger.warning(f"Height adjusted to {height} (must be multiple of 8)")

        # Fix: prompt may be a list — slicing a list logged a list repr,
        # not a text preview.
        preview = prompt if isinstance(prompt, str) else "; ".join(prompt)
        logger.info(f"Generating: {preview[:100]}...")

        try:
            # Mixed precision on CUDA; plain no-grad elsewhere (the
            # diffusers call is already no-grad internally — belt and braces).
            context = torch.autocast(self.device) if self.device == "cuda" else torch.no_grad()
            with context:
                output = self.pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    width=width,
                    height=height,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    num_images_per_prompt=num_images,
                    generator=generator
                )

            images = output.images
            logger.info(f"✓ Generated {len(images)} image(s)")

            if return_dict:
                return {
                    'images': images,
                    'prompt': prompt,
                    'negative_prompt': negative_prompt,
                    'width': width,
                    'height': height,
                    'steps': num_inference_steps,
                    'guidance': guidance_scale,
                    'seed': seed
                }

            return images[0] if len(images) == 1 else images

        except torch.cuda.OutOfMemoryError:
            logger.error("GPU out of memory! Try:")
            logger.error("  1. Reduce resolution (e.g., 512x512 instead of 1024x1024)")
            logger.error("  2. Reduce num_images")
            logger.error("  3. Close other applications")
            raise
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise

    def generate_batch(
        self,
        prompts: List[str],
        output_dir: str = "./outputs",
        **kwargs
    ) -> List[Image.Image]:
        """
        Generate multiple images from different prompts

        Args:
            prompts: List of text prompts
            output_dir: Directory to save images
            **kwargs: Additional generation parameters

        Returns:
            List of generated images (flattened if num_images > 1 per prompt)
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        images = []
        logger.info(f"Generating batch of {len(prompts)} images...")

        for i, prompt in enumerate(prompts):
            logger.info(f"  [{i+1}/{len(prompts)}] {prompt[:50]}...")

            result = self(prompt, **kwargs)
            # Fix: __call__ returns a list when num_images > 1; the old code
            # called .save() on that list and crashed.
            batch = result if isinstance(result, list) else [result]

            for j, img in enumerate(batch):
                if len(batch) == 1:
                    filename = output_path / f"image_{i:04d}.png"
                else:
                    filename = output_path / f"image_{i:04d}_{j}.png"
                img.save(filename)
                # Fix: previously logged the literal text "(unknown)"
                # instead of the destination path.
                logger.info(f"  ✓ Saved to {filename}")

            images.extend(batch)

        logger.info(f"✓ Batch complete! {len(images)} images in {output_dir}")
        return images

    def generate_variations(
        self,
        prompt: str,
        num_variations: int = 4,
        **kwargs
    ) -> List[Image.Image]:
        """
        Generate variations of the same prompt (different seeds)

        Args:
            prompt: Text prompt
            num_variations: Number of variations
            **kwargs: Additional generation parameters

        Returns:
            List of image variations
        """
        logger.info(f"Generating {num_variations} variations...")

        images = []
        for i in range(num_variations):
            # Fresh random seed per variation so each differs.
            seed = torch.randint(0, 2**32, (1,)).item()
            image = self(prompt, seed=seed, **kwargs)
            images.append(image)
            logger.info(f"  ✓ Variation {i+1}/{num_variations}")

        return images

    def set_scheduler(self, scheduler_name: str):
        """
        Change the diffusion scheduler

        Args:
            scheduler_name: 'dpm' (fast), 'euler' (creative), 'ddim' (stable)
        """
        schedulers = {
            'dpm': DPMSolverMultistepScheduler,
            'euler': EulerAncestralDiscreteScheduler,
            'ddim': DDIMScheduler,
        }

        # Unknown names warn and leave the current scheduler untouched.
        if scheduler_name.lower() not in schedulers:
            logger.warning(f"Unknown scheduler: {scheduler_name}")
            return

        scheduler_class = schedulers[scheduler_name.lower()]
        self.pipe.scheduler = scheduler_class.from_config(
            self.pipe.scheduler.config
        )
        logger.info(f"✓ Scheduler changed to {scheduler_name}")

    def save_pipeline(self, save_path: str):
        """Save the complete pipeline (model weights + config) to disk."""
        self.pipe.save_pretrained(save_path)
        logger.info(f"✓ Pipeline saved to {save_path}")

    def get_config(self) -> Dict:
        """Return the current pipeline configuration as a plain dict."""
        return {
            'model_id': self.model_id,
            'device': str(self.device),
            'dtype': str(self.dtype),
            'scheduler': self.pipe.scheduler.__class__.__name__,
            'default_negative_prompt': self.default_negative
        }
|
|
|
|
| |
| |
| |
|
|
def quick_generate(
    prompt: str,
    output_path: str = "output.png",
    quality: str = "balanced",
    **kwargs
) -> Image.Image:
    """
    Quick one-line image generation

    Args:
        prompt: What to generate
        output_path: Where to save
        quality: 'draft' (fast), 'balanced', 'high', 'ultra'
        **kwargs: Additional parameters (override the quality preset)

    Returns:
        Generated image

    Example:
        >>> quick_generate("a cat in a hat", "cat.png")
    """
    # Each preset maps a quality label onto step count and resolution;
    # an unrecognized label falls back to 'balanced'.
    presets = {
        'draft': {'num_inference_steps': 15, 'width': 512, 'height': 512},
        'balanced': {'num_inference_steps': 30, 'width': 512, 'height': 512},
        'high': {'num_inference_steps': 40, 'width': 768, 'height': 768},
        'ultra': {'num_inference_steps': 50, 'width': 1024, 'height': 1024}
    }

    # Caller-supplied kwargs win over the preset values.
    generation_args = {**presets.get(quality, presets['balanced']), **kwargs}

    generator = TrouterImagePipeline()
    result = generator(prompt, **generation_args)
    result.save(output_path)

    logger.info(f"✓ Image saved to {output_path}")
    return result
|
|
|
|
def batch_from_file(
    prompts_file: str,
    output_dir: str = "./outputs",
    **kwargs
) -> List[Image.Image]:
    """
    Generate images from prompts in a text file

    Args:
        prompts_file: Text file with one prompt per line
        output_dir: Where to save images
        **kwargs: Generation parameters

    Returns:
        List of generated images
    """
    # Read the whole file, drop blank lines and surrounding whitespace.
    raw_lines = Path(prompts_file).read_text().splitlines()
    prompts = [entry.strip() for entry in raw_lines if entry.strip()]

    return TrouterImagePipeline().generate_batch(prompts, output_dir, **kwargs)
|
|
|
|
| |
| |
| |
|
|
# Named style presets consumed by generate_with_style(): each entry supplies
# a suffix appended to the user's prompt, a style-appropriate negative
# prompt, and a guidance scale tuned for that aesthetic.
STYLE_PRESETS = {
    'photorealistic': {
        'prompt_suffix': ', professional photography, photorealistic, 4k, highly detailed',
        'negative_prompt': 'cartoon, anime, painting, illustration, low quality, blurry',
        'guidance_scale': 8.5
    },
    'artistic': {
        'prompt_suffix': ', digital art, concept art, detailed illustration',
        'negative_prompt': 'photograph, realistic, blurry, low quality',
        'guidance_scale': 7.0
    },
    'anime': {
        'prompt_suffix': ', anime style, manga, cel shaded, vibrant colors',
        'negative_prompt': 'realistic, 3d, photograph, blurry, low quality',
        'guidance_scale': 7.5
    },
    'oil_painting': {
        'prompt_suffix': ', oil painting, painterly, artistic, brushstrokes',
        'negative_prompt': 'photograph, digital, 3d render, blurry',
        'guidance_scale': 7.5
    },
    'cinematic': {
        'prompt_suffix': ', cinematic lighting, film still, dramatic, movie scene',
        'negative_prompt': 'amateur, low quality, poor lighting, blurry',
        'guidance_scale': 8.0
    }
}
|
|
|
|
def generate_with_style(
    prompt: str,
    style: str = 'photorealistic',
    output_path: str = "styled_output.png",
    **kwargs
) -> Image.Image:
    """
    Generate image with predefined style preset

    The preset contributes a prompt suffix, a negative prompt and a
    guidance scale. Explicit keyword arguments supplied by the caller take
    precedence over the preset values.

    Args:
        prompt: Base prompt
        style: Style preset name (falls back to 'photorealistic' if unknown)
        output_path: Where to save
        **kwargs: Additional parameters forwarded to the pipeline

    Returns:
        Generated image
    """
    if style not in STYLE_PRESETS:
        logger.warning(f"Unknown style: {style}, using photorealistic")
        style = 'photorealistic'

    preset = STYLE_PRESETS[style]

    full_prompt = prompt + preset['prompt_suffix']
    # Bug fix: use setdefault so a caller-supplied negative_prompt or
    # guidance_scale is respected instead of being silently clobbered by
    # the preset (previously the CLI --negative flag was ignored whenever
    # a style was selected).
    kwargs.setdefault('negative_prompt', preset['negative_prompt'])
    kwargs.setdefault('guidance_scale', preset['guidance_scale'])

    pipeline = TrouterImagePipeline()
    image = pipeline(full_prompt, **kwargs)
    image.save(output_path)

    logger.info(f"✓ {style.title()} style image saved to {output_path}")
    return image
|
|
|
|
| |
| |
| |
|
|
def main():
    """Simple command line interface"""
    import argparse

    parser = argparse.ArgumentParser(description="Trouter-Imagine-1 Image Generator")
    parser.add_argument("prompt", type=str, help="Text prompt for generation")
    parser.add_argument("--output", "-o", type=str, default="output.png",
                        help="Output file path")
    parser.add_argument("--quality", "-q", type=str, default="balanced",
                        choices=['draft', 'balanced', 'high', 'ultra'],
                        help="Quality preset")
    parser.add_argument("--style", "-s", type=str,
                        choices=list(STYLE_PRESETS.keys()),
                        help="Style preset")
    parser.add_argument("--seed", type=int, help="Random seed")
    parser.add_argument("--width", type=int, default=512, help="Image width")
    parser.add_argument("--height", type=int, default=512, help="Image height")
    parser.add_argument("--steps", type=int, default=30, help="Inference steps")
    parser.add_argument("--guidance", type=float, default=7.5, help="Guidance scale")
    parser.add_argument("--negative", type=str, help="Negative prompt")

    opts = parser.parse_args()

    # Common generation parameters forwarded to the pipeline.
    generation_kwargs = dict(
        width=opts.width,
        height=opts.height,
        num_inference_steps=opts.steps,
        guidance_scale=opts.guidance,
        seed=opts.seed,
    )
    if opts.negative:
        generation_kwargs['negative_prompt'] = opts.negative

    # A style preset routes through generate_with_style; otherwise use
    # the plain quality-preset path.
    if opts.style:
        generate_with_style(opts.prompt, opts.style, opts.output, **generation_kwargs)
    else:
        quick_generate(opts.prompt, opts.output, opts.quality, **generation_kwargs)
|
|
|
|
if __name__ == "__main__":
    # Print a usage banner, then only invoke the CLI when arguments
    # were actually supplied.
    rule = "=" * 70
    banner = [
        rule,
        "TROUTER-IMAGINE-1 IMAGE GENERATION PIPELINE",
        "Apache 2.0 License",
        rule,
        "",
        "Quick Start Examples:",
        "",
        "  # Python:",
        "  from pipeline import TrouterImagePipeline",
        "  pipeline = TrouterImagePipeline()",
        "  image = pipeline('a beautiful sunset over mountains')",
        "  image.save('sunset.png')",
        "",
        "  # Command line:",
        "  python pipeline.py 'a cat in a hat' --output cat.png",
        "  python pipeline.py 'portrait' --style photorealistic --quality high",
        "",
        rule,
        "",
    ]
    for line in banner:
        print(line)

    import sys
    if len(sys.argv) > 1:
        main()