Update services/generation.py
services/generation.py  CHANGED  (+38 -21)
@@ -1,13 +1,14 @@
 # services/generation.py
 import torch
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
-from diffusers import StableDiffusionPipeline, DiffusionPipeline, DPMSolverMultistepScheduler
 from PIL import Image
 import config
 from utils.helpers import decode_base64_image, encode_image_base64, encode_video_base64
 import logging
 import gc # Garbage collector
 from typing import List
+from diffusers import StableDiffusionPipeline, DiffusionPipeline, DPMSolverMultistepScheduler, LCMScheduler # Import LCMScheduler
+from peft import PeftConfig # Import PeftConfig (if needed, usually handled by load_lora_weights)
 
 logger = logging.getLogger(__name__)
 
@@ -27,27 +28,43 @@ def load_models():
     model_cache["text_model"] = AutoModelForSeq2SeqLM.from_pretrained(config.TEXT_MODEL_NAME).to(config.DEVICE)
     logger.info("Text model loaded.")
 
-
-
-
-
-
-
-
+    # --- Image Generation Model ---
+    logger.info(f"Loading image model: {config.IMAGE_MODEL_NAME}")
+    image_pipeline = StableDiffusionPipeline.from_pretrained(
+        config.IMAGE_MODEL_NAME,
+        torch_dtype=config.DTYPE
+    )
+
+    # --- Load LCM LoRA ---
+    try:
+        logger.info(f"Loading LCM LoRA: {config.IMAGE_LCM_LORA_NAME}")
+        # Load LoRA weights directly into the pipeline
+        image_pipeline.load_lora_weights(config.IMAGE_LCM_LORA_NAME)
+        # Fuse LoRA for potential speedup (optional, test impact)
+        # image_pipeline.fuse_lora()
+        logger.info("LCM LoRA loaded successfully.")
+
+        # --- IMPORTANT: Set LCM Scheduler ---
+        image_pipeline.scheduler = LCMScheduler.from_config(image_pipeline.scheduler.config)
+        logger.info("Switched scheduler to LCMScheduler.")
+
+    except Exception as e:
+        logger.warning(f"Could not load or apply LCM LoRA '{config.IMAGE_LCM_LORA_NAME}'. Falling back to base model scheduler. Error: {e}", exc_info=True)
+        # Fallback to a standard fast scheduler if LCM fails
         image_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(image_pipeline.scheduler.config)
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+    image_pipeline = image_pipeline.to(config.DEVICE)
+    if config.DEVICE == "cuda":
+        try:
+            # image_pipeline.enable_xformers_memory_efficient_attention()
+            pass
+        except ImportError:
+            logger.warning("xformers not installed...")
+            # image_pipeline.enable_attention_slicing()
+
+    model_cache["image_pipeline"] = image_pipeline
+    logger.info("Image model setup complete.")
 
 
     # Video Generation Model
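One note on why the scheduler swap matters to callers: with the LCM LoRA active, the pipeline should be run with very few denoising steps and classifier-free guidance effectively disabled, while the DPM-Solver fallback still expects a conventional step count. Below is a minimal sketch of how a caller might branch on the active scheduler. It is illustrative only: generate_image and the exact step counts are assumptions, not part of this commit (4-8 steps at guidance_scale near 1.0 is what the LCM-LoRA documentation recommends; 20-25 steps is a typical DPM-Solver setting).

# Hypothetical caller in the same module, sketching intended use of
# model_cache["image_pipeline"] after load_models() has run.
from diffusers import LCMScheduler

def generate_image(prompt: str) -> str:
    pipeline = model_cache["image_pipeline"]
    if isinstance(pipeline.scheduler, LCMScheduler):
        # LCM path: few steps; guidance_scale=1.0 disables classifier-free guidance
        result = pipeline(prompt, num_inference_steps=4, guidance_scale=1.0)
    else:
        # DPM-Solver fallback: conventional step count and guidance
        result = pipeline(prompt, num_inference_steps=25, guidance_scale=7.5)
    return encode_image_base64(result.images[0])  # helper imported at top of this file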
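The new code also leans on four config attributes that never appear in this diff: IMAGE_MODEL_NAME, IMAGE_LCM_LORA_NAME, DTYPE, and DEVICE. For load_lora_weights to succeed, the adapter has to match the base checkpoint's family (an SD 1.5 LCM LoRA will not apply to an SDXL base). A plausible config.py shape follows; the attribute names come from the diff, but every value is a hypothetical example rather than this repo's actual settings.

# config.py (sketch; names are taken from the diff, values are assumed)
import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32  # fp16 is only safe on GPU
IMAGE_MODEL_NAME = "runwayml/stable-diffusion-v1-5"            # assumed SD 1.5 base
IMAGE_LCM_LORA_NAME = "latent-consistency/lcm-lora-sdv1-5"     # matching LCM LoRA adapter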