rajux75 committed
Commit 5f7ce0f · verified · 1 parent: 204e1ea

Update services/generation.py

Files changed (1)
1. services/generation.py (+38 -21)
services/generation.py CHANGED
@@ -1,13 +1,14 @@
 # services/generation.py
 import torch
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
-from diffusers import StableDiffusionPipeline, DiffusionPipeline, DPMSolverMultistepScheduler
 from PIL import Image
 import config
 from utils.helpers import decode_base64_image, encode_image_base64, encode_video_base64
 import logging
 import gc # Garbage collector
 from typing import List
+from diffusers import StableDiffusionPipeline, DiffusionPipeline, DPMSolverMultistepScheduler, LCMScheduler # Import LCMScheduler
+from peft import PeftConfig # Optional: LoRA loading is usually handled by load_lora_weights
 
 logger = logging.getLogger(__name__)
 
@@ -27,27 +28,43 @@ def load_models():
     model_cache["text_model"] = AutoModelForSeq2SeqLM.from_pretrained(config.TEXT_MODEL_NAME).to(config.DEVICE)
     logger.info("Text model loaded.")
 
-    # Image Generation Model
-    logger.info(f"Loading image model: {config.IMAGE_MODEL_NAME}")
-    image_pipeline = StableDiffusionPipeline.from_pretrained(
-        config.IMAGE_MODEL_NAME,
-        torch_dtype=config.DTYPE
-    )
-    # Optimization: Use a faster scheduler
-    image_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(image_pipeline.scheduler.config)
-    image_pipeline = image_pipeline.to(config.DEVICE)
-    # Optional: Enable attention slicing for lower VRAM usage on GPU
-    if config.DEVICE == "cuda":
-        try:
-            # Requires pip install xformers - uncomment if installed
-            # image_pipeline.enable_xformers_memory_efficient_attention()
-            pass # Use default if xformers not installed/wanted
-        except ImportError:
-            logger.warning("xformers not installed. Memory efficient attention not enabled.")
-        # image_pipeline.enable_attention_slicing() # Alternative if xformers not available
-
-    model_cache["image_pipeline"] = image_pipeline
-    logger.info("Image model loaded.")
+    # --- Image Generation Model ---
+    logger.info(f"Loading image model: {config.IMAGE_MODEL_NAME}")
+    image_pipeline = StableDiffusionPipeline.from_pretrained(
+        config.IMAGE_MODEL_NAME,
+        torch_dtype=config.DTYPE
+    )
+
+    # --- Load LCM LoRA ---
+    try:
+        logger.info(f"Loading LCM LoRA: {config.IMAGE_LCM_LORA_NAME}")
+        # Load LoRA weights directly into the pipeline
+        image_pipeline.load_lora_weights(config.IMAGE_LCM_LORA_NAME)
+        # Fuse LoRA for a potential speedup (optional; measure the impact)
+        # image_pipeline.fuse_lora()
+        logger.info("LCM LoRA loaded successfully.")
+
+        # --- IMPORTANT: LCM LoRA requires the LCM scheduler ---
+        image_pipeline.scheduler = LCMScheduler.from_config(image_pipeline.scheduler.config)
+        logger.info("Switched scheduler to LCMScheduler.")
+
+    except Exception as e:
+        logger.warning(f"Could not load or apply LCM LoRA '{config.IMAGE_LCM_LORA_NAME}'. Falling back to base model scheduler. Error: {e}", exc_info=True)
+        # Fall back to a standard fast scheduler if the LCM LoRA fails
+        image_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(image_pipeline.scheduler.config)
+
+
+    image_pipeline = image_pipeline.to(config.DEVICE)
+    if config.DEVICE == "cuda":
+        try:
+            # image_pipeline.enable_xformers_memory_efficient_attention()
+            pass
+        except ImportError:
+            logger.warning("xformers not installed. Memory efficient attention not enabled.")
+        # image_pipeline.enable_attention_slicing()
+
+    model_cache["image_pipeline"] = image_pipeline
+    logger.info("Image model setup complete.")
 
 
     # Video Generation Model
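The new code reads several values from the config module that are not part of this diff. A minimal sketch of what that module might contain, assuming a Stable Diffusion 1.5 base (every name below is an illustrative placeholder, not taken from the commit; an SDXL base would need the matching SDXL LCM LoRA instead):

# config.py -- hypothetical sketch; the real module is not shown in this diff.
import torch

TEXT_MODEL_NAME = "google/flan-t5-base"                     # placeholder seq2seq model
IMAGE_MODEL_NAME = "runwayml/stable-diffusion-v1-5"         # placeholder SD 1.5 base
IMAGE_LCM_LORA_NAME = "latent-consistency/lcm-lora-sdv1-5"  # placeholder; must match the base model family
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32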
 
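For reference, a usage sketch of the pipeline this diff produces: once the LCM LoRA and LCMScheduler are active, inference wants very few steps and little to no classifier-free guidance. The caller below is hypothetical (the generation endpoint is outside this diff), and it assumes model_cache is a module-level dict in services/generation, as the diff suggests.

# Hypothetical caller -- the generation endpoint itself is outside this diff.
from services.generation import load_models, model_cache

load_models()
pipe = model_cache["image_pipeline"]

# LCM-distilled pipelines converge in roughly 4-8 steps; guidance_scale
# should stay near 1.0 (higher values tend to oversaturate LCM outputs).
image = pipe(
    "a watercolor fox in a pine forest",  # example prompt
    num_inference_steps=4,
    guidance_scale=1.0,
).images[0]
image.save("fox.png")

Note that if the except branch fired and the pipeline fell back to DPMSolverMultistepScheduler, these settings are wrong for that path (it wants roughly 25 steps and normal guidance), so it may be worth recording which scheduler ended up active alongside the cached pipeline.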