Trouter-Library committed on
Commit
3fd5fbc
·
verified ·
1 Parent(s): b30e094

Create pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +551 -0
pipeline.py ADDED
@@ -0,0 +1,551 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Trouter-Imagine-1 Complete Pipeline
4
+ Apache 2.0 License
5
+
6
+ This file provides a complete, ready-to-use pipeline for text-to-image generation.
7
+ It includes all necessary components and can be used immediately for generating images.
8
+
9
+ This is the MAIN FILE for using the model - simple and powerful.
10
+ """
11
+
12
+ import torch
13
+ from diffusers import (
14
+ StableDiffusionPipeline,
15
+ DPMSolverMultistepScheduler,
16
+ EulerAncestralDiscreteScheduler,
17
+ DDIMScheduler
18
+ )
19
+ from PIL import Image
20
+ import os
21
+ from typing import List, Optional, Union, Dict
22
+ import warnings
23
+ import logging
24
+ from pathlib import Path
25
+ import json
26
+ from datetime import datetime
27
+
28
# Configure root logging once at import time (timestamp - level - message)
# and expose the module-level logger used by everything in this file.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
33
+
34
+
35
class TrouterImagePipeline:
    """
    Complete ready-to-use pipeline for Trouter-Imagine-1

    This is the main class you should use for image generation. It wraps a
    diffusers ``StableDiffusionPipeline``: it picks a device, loads the model,
    installs a DPM-Solver scheduler, applies memory optimizations, and exposes
    a simple call interface plus batch / variation helpers.

    Example:
        >>> pipeline = TrouterImagePipeline()
        >>> image = pipeline("a beautiful sunset")
        >>> image.save("sunset.png")
    """

    # Default base model (you can change this to your custom model once trained)
    DEFAULT_MODEL = "runwayml/stable-diffusion-v1-5"

    # Can also use these alternatives:
    #   "stabilityai/stable-diffusion-2-1"
    #   "stabilityai/stable-diffusion-xl-base-1.0"
    # NOTE(review): SDXL checkpoints are normally loaded through
    # StableDiffusionXLPipeline, which this class does not use - confirm
    # before pointing model_id at an SDXL repository.

    def __init__(
        self,
        model_id: Optional[str] = None,
        device: Optional[str] = None,
        torch_dtype: torch.dtype = torch.float16,
        use_safetensors: bool = True,
        enable_optimizations: bool = True
    ):
        """
        Initialize the Trouter-Imagine-1 pipeline

        Args:
            model_id: Model to use (defaults to ``DEFAULT_MODEL``)
            device: Device to use (auto-detected if None: cuda, then mps, then cpu)
            torch_dtype: Model precision (float16 for speed, float32 for quality).
                float16 is replaced by float32 when the device is the CPU.
            use_safetensors: Use safetensors format (recommended)
            enable_optimizations: Enable memory optimizations

        Raises:
            Exception: Whatever the underlying model loading raises (logged, then re-raised).
        """
        if device is None:
            device = self._detect_device()

        # Half precision generally does not run on CPU; degrade gracefully
        # instead of failing later in the middle of generation.
        effective_dtype = self._resolve_dtype(device, torch_dtype)
        if effective_dtype != torch_dtype:
            logger.warning(
                f"{torch_dtype} is not supported for CPU inference, using {effective_dtype} instead"
            )

        self.device = device
        self.dtype = effective_dtype
        self.model_id = model_id or self.DEFAULT_MODEL

        logger.info("Initializing Trouter-Imagine-1 Pipeline")
        logger.info(f"Model: {self.model_id}")
        logger.info(f"Device: {self.device}")
        logger.info(f"Precision: {self.dtype}")

        # Load pipeline
        self._load_pipeline(use_safetensors)

        # Apply optimizations
        if enable_optimizations:
            self._optimize()

        # Negative prompt applied whenever the caller does not supply one
        self.default_negative = "blurry, low quality, distorted, deformed, ugly, bad anatomy, watermark, signature, text"

        logger.info("✓ Pipeline ready!")

    @staticmethod
    def _detect_device() -> str:
        """Pick the best available device: CUDA first, then Apple MPS, then CPU."""
        if torch.cuda.is_available():
            logger.info(f"CUDA available: {torch.cuda.get_device_name(0)}")
            return "cuda"

        # torch.backends.mps is missing on older torch builds, so look it up defensively.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            logger.info("Using Apple Silicon (MPS)")
            return "mps"

        logger.warning("No GPU detected, using CPU (will be slow)")
        return "cpu"

    @staticmethod
    def _resolve_dtype(device, torch_dtype: torch.dtype) -> torch.dtype:
        """Return the dtype to load with: float16 becomes float32 on CPU, otherwise unchanged."""
        device_kind = str(device).split(":")[0]
        if device_kind == "cpu" and torch_dtype == torch.float16:
            return torch.float32
        return torch_dtype

    @staticmethod
    def _snap_dimension(value: int, name: str) -> int:
        """Round ``value`` down to a multiple of 8; raise ValueError if nothing is left."""
        snapped = (value // 8) * 8
        if snapped < 8:
            raise ValueError(f"{name} must be at least 8 pixels, got {value}")
        return snapped

    @staticmethod
    def _as_image_list(result) -> list:
        """Normalize a ``__call__`` result (image, list of images, or metadata dict) to a list."""
        if isinstance(result, dict):
            result = result['images']
        if isinstance(result, (list, tuple)):
            return list(result)
        return [result]

    def _load_pipeline(self, use_safetensors: bool):
        """Load the diffusion pipeline, move it to the device and install the default scheduler"""
        try:
            self.pipe = StableDiffusionPipeline.from_pretrained(
                self.model_id,
                torch_dtype=self.dtype,
                use_safetensors=use_safetensors,
                safety_checker=None,  # Disable for flexibility
                requires_safety_checker=False
            )

            # Move to device
            self.pipe = self.pipe.to(self.device)

            # Set better scheduler by default
            self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                self.pipe.scheduler.config
            )

            logger.info("✓ Model loaded successfully")

        except Exception as e:
            # Boundary logging only; the caller still sees the original exception.
            logger.error(f"Failed to load model: {e}")
            raise

    def _optimize(self):
        """Apply memory and speed optimizations (best effort: failures are logged, not raised)"""
        logger.info("Applying optimizations...")

        try:
            # Memory optimizations
            self.pipe.enable_attention_slicing()
            self.pipe.enable_vae_slicing()
            logger.info(" ✓ Memory optimizations enabled")
        except Exception as e:
            logger.warning(f" ⚠ Memory optimization failed: {e}")

        # Try xformers for even better performance; it is an optional extra,
        # so any failure here is deliberately treated as "not available".
        try:
            self.pipe.enable_xformers_memory_efficient_attention()
            logger.info(" ✓ xformers enabled (faster generation)")
        except Exception:
            logger.info(" ℹ xformers not available (this is fine)")

        # Model CPU offload for very limited VRAM
        # Uncomment if you have < 6GB VRAM:
        # self.pipe.enable_model_cpu_offload()

    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        width: int = 512,
        height: int = 512,
        num_inference_steps: int = 30,
        guidance_scale: float = 7.5,
        num_images: int = 1,
        seed: Optional[int] = None,
        return_dict: bool = False
    ) -> Union[Image.Image, List[Image.Image], Dict]:
        """
        Generate images from text prompt

        Args:
            prompt: Text description or list of descriptions
            negative_prompt: What to avoid (uses ``self.default_negative`` if None)
            width: Image width (rounded down to a multiple of 8)
            height: Image height (rounded down to a multiple of 8)
            num_inference_steps: Quality (20=fast, 30=balanced, 50=quality)
            guidance_scale: Prompt adherence (7.5 is good default)
            num_images: Number of images to generate per prompt
            seed: Random seed for reproducibility
            return_dict: Return dictionary with metadata

        Returns:
            A single image when exactly one was produced, otherwise a list of
            images; with ``return_dict=True`` a dictionary holding the images
            and the generation parameters.

        Raises:
            ValueError: If width or height is smaller than 8 pixels.
            torch.cuda.OutOfMemoryError: Re-raised after logging hints.
            Exception: Any other generation failure is logged and re-raised.
        """
        # Use default negative prompt if none provided
        if negative_prompt is None:
            negative_prompt = self.default_negative

        # Set seed if provided
        generator = None
        if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(seed)

        # Validate dimensions
        snapped_width = self._snap_dimension(width, "width")
        if snapped_width != width:
            width = snapped_width
            logger.warning(f"Width adjusted to {width} (must be multiple of 8)")
        snapped_height = self._snap_dimension(height, "height")
        if snapped_height != height:
            height = snapped_height
            logger.warning(f"Height adjusted to {height} (must be multiple of 8)")

        # Generate
        if isinstance(prompt, str):
            preview = prompt
        else:
            # A list of prompts: join them so the 100-char cut applies to text, not list items.
            preview = "; ".join(str(item) for item in prompt)
        logger.info(f"Generating: {preview[:100]}...")

        # Autocast on any CUDA device ("cuda", "cuda:0", ...); plain no_grad elsewhere.
        if str(self.device).startswith("cuda"):
            run_context = torch.autocast("cuda")
        else:
            run_context = torch.no_grad()

        try:
            with run_context:
                output = self.pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    width=width,
                    height=height,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    num_images_per_prompt=num_images,
                    generator=generator
                )

            images = output.images
            logger.info(f"✓ Generated {len(images)} image(s)")

            if return_dict:
                return {
                    'images': images,
                    'prompt': prompt,
                    'negative_prompt': negative_prompt,
                    'width': width,
                    'height': height,
                    'steps': num_inference_steps,
                    'guidance': guidance_scale,
                    'seed': seed
                }

            return images[0] if len(images) == 1 else images

        except torch.cuda.OutOfMemoryError:
            logger.error("GPU out of memory! Try:")
            logger.error(" 1. Reduce resolution (e.g., 512x512 instead of 1024x1024)")
            logger.error(" 2. Reduce num_images")
            logger.error(" 3. Close other applications")
            raise
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise

    def generate_batch(
        self,
        prompts: List[str],
        output_dir: str = "./outputs",
        **kwargs
    ) -> List[Image.Image]:
        """
        Generate multiple images from different prompts

        Each prompt is generated in turn and saved as ``image_NNNN.png`` in
        ``output_dir``. When one prompt yields several images (``num_images``
        greater than 1) they are saved as ``image_NNNN_MM.png``.

        Args:
            prompts: List of text prompts
            output_dir: Directory to save images (created if missing)
            **kwargs: Additional generation parameters forwarded to ``__call__``

        Returns:
            Flat list of all generated images, in prompt order
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        images = []
        total = len(prompts)
        logger.info(f"Generating batch of {total} images...")

        for i, prompt in enumerate(prompts):
            logger.info(f" [{i+1}/{total}] {prompt[:50]}...")

            # __call__ may hand back one image, a list, or a metadata dict;
            # normalize so saving never hits an object without .save().
            produced = self._as_image_list(self(prompt, **kwargs))

            # Save
            for j, image in enumerate(produced):
                if len(produced) == 1:
                    filename = output_path / f"image_{i:04d}.png"
                else:
                    filename = output_path / f"image_{i:04d}_{j:02d}.png"
                image.save(filename)
                logger.info(f" ✓ Saved to {filename}")

            images.extend(produced)

        logger.info(f"✓ Batch complete! {len(images)} images in {output_dir}")
        return images

    def generate_variations(
        self,
        prompt: str,
        num_variations: int = 4,
        **kwargs
    ) -> List[Image.Image]:
        """
        Generate variations of the same prompt (different seeds)

        Args:
            prompt: Text prompt
            num_variations: Number of variations
            **kwargs: Additional generation parameters. If ``seed`` is given it
                is used as a base: variation ``i`` runs with ``seed + i``, so
                the whole set is reproducible. Without it each variation gets
                a fresh random seed.

        Returns:
            List with one ``__call__`` result per variation
        """
        logger.info(f"Generating {num_variations} variations...")

        # Pull seed out of kwargs so it cannot collide with the seed= we pass below.
        base_seed = kwargs.pop('seed', None)

        images = []
        for i in range(num_variations):
            if base_seed is None:
                seed = torch.randint(0, 2**32, (1,)).item()
            else:
                seed = base_seed + i
            image = self(prompt, seed=seed, **kwargs)
            images.append(image)
            logger.info(f" ✓ Variation {i+1}/{num_variations}")

        return images

    def set_scheduler(self, scheduler_name: str):
        """
        Change the diffusion scheduler

        Unknown names are logged as a warning and leave the scheduler unchanged.

        Args:
            scheduler_name: 'dpm' (fast), 'euler' (creative), 'ddim' (stable); case-insensitive
        """
        schedulers = {
            'dpm': DPMSolverMultistepScheduler,
            'euler': EulerAncestralDiscreteScheduler,
            'ddim': DDIMScheduler,
        }

        scheduler_class = schedulers.get(scheduler_name.lower())
        if scheduler_class is None:
            logger.warning(f"Unknown scheduler: {scheduler_name}")
            return

        # Rebuild from the current scheduler's config so model settings carry over.
        self.pipe.scheduler = scheduler_class.from_config(
            self.pipe.scheduler.config
        )
        logger.info(f"✓ Scheduler changed to {scheduler_name}")

    def save_pipeline(self, save_path: str):
        """Save the complete pipeline to ``save_path`` via ``save_pretrained``"""
        self.pipe.save_pretrained(save_path)
        logger.info(f"✓ Pipeline saved to {save_path}")

    def get_config(self) -> Dict:
        """Get current pipeline configuration (model, device, dtype, scheduler, default negative prompt)"""
        return {
            'model_id': self.model_id,
            'device': str(self.device),
            'dtype': str(self.dtype),
            'scheduler': self.pipe.scheduler.__class__.__name__,
            'default_negative_prompt': self.default_negative
        }
345
+
346
+
347
+ # ============================================================================
348
+ # CONVENIENCE FUNCTIONS
349
+ # ============================================================================
350
+
351
def quick_generate(
    prompt: str,
    output_path: str = "output.png",
    quality: str = "balanced",
    **kwargs
) -> Image.Image:
    """
    Quick one-line image generation

    Builds a fresh ``TrouterImagePipeline`` (loading the model each call),
    generates one result for ``prompt`` and saves it to ``output_path``.

    Args:
        prompt: What to generate
        output_path: Where to save
        quality: 'draft' (fast), 'balanced', 'high', 'ultra'. Unknown names
            are logged as a warning and fall back to 'balanced'.
        **kwargs: Additional parameters; these override the preset's values

    Returns:
        Generated image

    Example:
        >>> quick_generate("a cat in a hat", "cat.png")
    """
    quality_presets = {
        'draft': {'num_inference_steps': 15, 'width': 512, 'height': 512},
        'balanced': {'num_inference_steps': 30, 'width': 512, 'height': 512},
        'high': {'num_inference_steps': 40, 'width': 768, 'height': 768},
        'ultra': {'num_inference_steps': 50, 'width': 1024, 'height': 1024}
    }

    # Tell the caller about a mistyped preset instead of silently ignoring it.
    if quality not in quality_presets:
        logger.warning(f"Unknown quality: {quality}, using balanced")
        quality = 'balanced'

    # Merge into a new dict: preset first, explicit kwargs win.
    settings = {**quality_presets[quality], **kwargs}

    pipeline = TrouterImagePipeline()
    image = pipeline(prompt, **settings)
    image.save(output_path)

    logger.info(f"✓ Image saved to {output_path}")
    return image
388
+
389
+
390
def batch_from_file(
    prompts_file: str,
    output_dir: str = "./outputs",
    **kwargs
) -> List[Image.Image]:
    """
    Generate images from prompts in a text file

    Args:
        prompts_file: UTF-8 text file with one prompt per line; blank lines are skipped
        output_dir: Where to save images
        **kwargs: Generation parameters forwarded to ``generate_batch``

    Returns:
        List of generated images
    """
    # Explicit encoding so non-ASCII prompts read the same on every platform.
    with open(prompts_file, 'r', encoding='utf-8') as f:
        stripped_lines = (line.strip() for line in f)
        prompts = [text for text in stripped_lines if text]

    pipeline = TrouterImagePipeline()
    return pipeline.generate_batch(prompts, output_dir, **kwargs)
411
+
412
+
413
+ # ============================================================================
414
+ # PRESETS AND STYLES
415
+ # ============================================================================
416
+
417
# Style presets keyed by style name. Each entry carries a suffix appended to
# the user's prompt, a style-specific negative prompt and a guidance scale.
STYLE_PRESETS = dict(
    photorealistic=dict(
        prompt_suffix=', professional photography, photorealistic, 4k, highly detailed',
        negative_prompt='cartoon, anime, painting, illustration, low quality, blurry',
        guidance_scale=8.5,
    ),
    artistic=dict(
        prompt_suffix=', digital art, concept art, detailed illustration',
        negative_prompt='photograph, realistic, blurry, low quality',
        guidance_scale=7.0,
    ),
    anime=dict(
        prompt_suffix=', anime style, manga, cel shaded, vibrant colors',
        negative_prompt='realistic, 3d, photograph, blurry, low quality',
        guidance_scale=7.5,
    ),
    oil_painting=dict(
        prompt_suffix=', oil painting, painterly, artistic, brushstrokes',
        negative_prompt='photograph, digital, 3d render, blurry',
        guidance_scale=7.5,
    ),
    cinematic=dict(
        prompt_suffix=', cinematic lighting, film still, dramatic, movie scene',
        negative_prompt='amateur, low quality, poor lighting, blurry',
        guidance_scale=8.0,
    ),
)
444
+
445
+
446
def generate_with_style(
    prompt: str,
    style: str = 'photorealistic',
    output_path: str = "styled_output.png",
    **kwargs
) -> Image.Image:
    """
    Generate image with predefined style preset

    The preset's suffix is appended to ``prompt``. The preset's negative
    prompt and guidance scale act as defaults: a ``negative_prompt`` or
    ``guidance_scale`` passed by the caller takes precedence.

    Args:
        prompt: Base prompt
        style: Style preset name (a key of ``STYLE_PRESETS``); unknown names
            are logged as a warning and fall back to 'photorealistic'
        output_path: Where to save
        **kwargs: Additional parameters forwarded to the pipeline call

    Returns:
        Generated image
    """
    if style not in STYLE_PRESETS:
        logger.warning(f"Unknown style: {style}, using photorealistic")
        style = 'photorealistic'

    preset = STYLE_PRESETS[style]

    # Apply style
    full_prompt = prompt + preset['prompt_suffix']
    # Fill in preset values only where the caller gave nothing (missing or None),
    # so an explicit choice is never silently discarded.
    for option in ('negative_prompt', 'guidance_scale'):
        if kwargs.get(option) is None:
            kwargs[option] = preset[option]

    pipeline = TrouterImagePipeline()
    image = pipeline(full_prompt, **kwargs)
    image.save(output_path)

    logger.info(f"✓ {style.title()} style image saved to {output_path}")
    return image


# ============================================================================
# MAIN - COMMAND LINE INTERFACE
# ============================================================================

def main():
    """
    Simple command line interface

    Parses the command line and dispatches to ``generate_with_style`` when
    ``--style`` is given, otherwise to ``quick_generate``. Only options the
    user actually supplied are forwarded, so quality and style presets keep
    their own width/height/steps/guidance unless explicitly overridden.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Trouter-Imagine-1 Image Generator")
    parser.add_argument("prompt", type=str, help="Text prompt for generation")
    parser.add_argument("--output", "-o", type=str, default="output.png",
                        help="Output file path")
    parser.add_argument("--quality", "-q", type=str, default="balanced",
                        choices=['draft', 'balanced', 'high', 'ultra'],
                        help="Quality preset (not used together with --style)")
    parser.add_argument("--style", "-s", type=str,
                        choices=list(STYLE_PRESETS.keys()),
                        help="Style preset")
    parser.add_argument("--seed", type=int, help="Random seed")
    # No argparse defaults below: None means "not given", which lets the
    # selected preset (or the pipeline's own default) decide the value.
    parser.add_argument("--width", type=int, default=None,
                        help="Image width (default: preset value, else 512)")
    parser.add_argument("--height", type=int, default=None,
                        help="Image height (default: preset value, else 512)")
    parser.add_argument("--steps", type=int, default=None,
                        help="Inference steps (default: preset value, else 30)")
    parser.add_argument("--guidance", type=float, default=None,
                        help="Guidance scale (default: preset value, else 7.5)")
    parser.add_argument("--negative", type=str, help="Negative prompt")

    args = parser.parse_args()

    requested = {
        'width': args.width,
        'height': args.height,
        'num_inference_steps': args.steps,
        'guidance_scale': args.guidance,
        'seed': args.seed,
        'negative_prompt': args.negative
    }
    # Forward only what the user set explicitly.
    kwargs = {name: value for name, value in requested.items() if value is not None}

    if args.style:
        generate_with_style(args.prompt, args.style, args.output, **kwargs)
    else:
        quick_generate(args.prompt, args.output, args.quality, **kwargs)
525
+
526
+
527
if __name__ == "__main__":
    import sys

    # Usage banner, printed on every direct run of this file.
    rule = "=" * 70
    banner_lines = [
        rule,
        "TROUTER-IMAGINE-1 IMAGE GENERATION PIPELINE",
        "Apache 2.0 License",
        rule,
        "",
        "Quick Start Examples:",
        "",
        " # Python:",
        " from pipeline import TrouterImagePipeline",
        " pipeline = TrouterImagePipeline()",
        " image = pipeline('a beautiful sunset over mountains')",
        " image.save('sunset.png')",
        "",
        " # Command line:",
        " python pipeline.py 'a cat in a hat' --output cat.png",
        " python pipeline.py 'portrait' --style photorealistic --quality high",
        "",
        rule,
        "",
    ]
    print("\n".join(banner_lines))

    # Run CLI if arguments provided
    has_cli_arguments = len(sys.argv) > 1
    if has_cli_arguments:
        main()