HARRY07979 commited on
Commit
107304c
·
verified ·
1 Parent(s): 937c9b0

Create LiteVisionPipeline.py

Browse files
Files changed (1) hide show
  1. LiteVisionPipeline.py +44 -0
LiteVisionPipeline.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import StableDiffusionPipeline, LCMScheduler
3
+
4
+
5
+ class LiteVisionPipeline(StableDiffusionPipeline):
6
+ """
7
+ LiteVisionPipeline v1
8
+ - Custom pipeline for LiteVision models
9
+ - LCM scheduler preset
10
+ - Low VRAM & fast inference defaults
11
+ - Compatible with DiffusionPipeline ecosystem
12
+ """
13
+
14
+ def __init__(self, *args, **kwargs):
15
+ super().__init__(*args, **kwargs)
16
+
17
+ # ---- Scheduler preset (LCM) ----
18
+ self.scheduler = LCMScheduler.from_config(self.scheduler.config)
19
+
20
+ # ---- Memory optimizations ----
21
+ self.enable_attention_slicing()
22
+ self.enable_vae_slicing()
23
+
24
+ @torch.inference_mode()
25
+ def __call__(
26
+ self,
27
+ prompt,
28
+ negative_prompt=None,
29
+ num_inference_steps: int = 6,
30
+ guidance_scale: float = 1.5,
31
+ **kwargs,
32
+ ):
33
+ """
34
+ LiteVision default generation call
35
+ - Optimized for LCM-style inference
36
+ """
37
+
38
+ return super().__call__(
39
+ prompt=prompt,
40
+ negative_prompt=negative_prompt,
41
+ num_inference_steps=num_inference_steps,
42
+ guidance_scale=guidance_scale,
43
+ **kwargs,
44
+ )