rajux75 committed
Commit 2a5411f · verified · 1 Parent(s): bdbe000

Create services/generation.py

Files changed (1)
  1. services/generation.py  +188 -0
services/generation.py ADDED
@@ -0,0 +1,188 @@
+ # services/generation.py
+ import torch
+ import numpy as np  # needed below to convert video frames for encoding
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ from diffusers import StableDiffusionPipeline, DiffusionPipeline, DPMSolverMultistepScheduler
+ from PIL import Image
+ import config
+ from utils.helpers import decode_base64_image, encode_image_base64, encode_video_base64
+ import logging
+ import gc  # Garbage collector, used to free memory after each generation
+
+ logger = logging.getLogger(__name__)
+
+ # --- Global Model Cache ---
+ # A dictionary holds the loaded models and tokenizers,
+ # so they are loaded only once, when the app starts.
+ model_cache = {}
+
+ def load_models():
+     """Loads all models into the cache. Called once at application startup."""
+     logger.info("Loading models...")
+     try:
+         # Text generation model
+         logger.info(f"Loading text model: {config.TEXT_MODEL_NAME}")
+         model_cache["text_tokenizer"] = AutoTokenizer.from_pretrained(config.TEXT_MODEL_NAME)
+         model_cache["text_model"] = AutoModelForSeq2SeqLM.from_pretrained(config.TEXT_MODEL_NAME).to(config.DEVICE)
+         logger.info("Text model loaded.")
+
+         # Image generation model
+         logger.info(f"Loading image model: {config.IMAGE_MODEL_NAME}")
+         image_pipeline = StableDiffusionPipeline.from_pretrained(
+             config.IMAGE_MODEL_NAME,
+             torch_dtype=config.DTYPE
+         )
+         # Optimization: swap in a faster scheduler
+         image_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(image_pipeline.scheduler.config)
+         image_pipeline = image_pipeline.to(config.DEVICE)
+         # Optional: memory-efficient attention for lower VRAM usage on GPU
+         if config.DEVICE == "cuda":
+             try:
+                 # Requires `pip install xformers`; skipped with a warning if unavailable
+                 image_pipeline.enable_xformers_memory_efficient_attention()
+             except ImportError:
+                 logger.warning("xformers not installed. Memory-efficient attention not enabled.")
+                 # image_pipeline.enable_attention_slicing()  # lower-VRAM alternative to xformers
+
+         model_cache["image_pipeline"] = image_pipeline
+         logger.info("Image model loaded.")
+
+         # Video generation model
+         logger.info(f"Loading video model: {config.VIDEO_MODEL_NAME}")
+         video_pipeline = DiffusionPipeline.from_pretrained(
+             config.VIDEO_MODEL_NAME,
+             torch_dtype=config.DTYPE,
+             variant="fp16" if config.DTYPE == torch.float16 else None  # Zeroscope often ships fp16 variants
+         )
+         video_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(video_pipeline.scheduler.config)
+         video_pipeline.enable_model_cpu_offload()  # crucial for low-VRAM environments such as Spaces CPU/T4
+         # video_pipeline = video_pipeline.to(config.DEVICE)  # not needed: CPU offload handles device placement
+
+         model_cache["video_pipeline"] = video_pipeline
+         logger.info("Video model loaded.")
+
+     except Exception as e:
+         logger.error(f"Error loading models: {e}", exc_info=True)
+         # Re-raise to prevent the app from starting with missing models;
+         # alternatively, swallow the error and let the affected endpoints fail.
+         raise
+
+     logger.info("All models loaded successfully.")
+
+
+ def generate_ideas_sync(prompt: str, max_length: int, num_ideas: int) -> list[str]:
+     """Synchronous function for text generation (run in a thread pool)."""
+     tokenizer = model_cache.get("text_tokenizer")
+     model = model_cache.get("text_model")
+     if not tokenizer or not model:
+         raise RuntimeError("Text model not loaded.")
+
+     # The prompt could be wrapped in an instruction template for better
+     # instruction following (e.g., for Flan-T5):
+     # input_text = f"Generate {num_ideas} content ideas about: {prompt}"
+     input_text = prompt  # keep the original prompt, per the request model
+
+     inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(config.DEVICE)  # cap input length for the model
+
+     # Generation parameters
+     outputs = model.generate(
+         **inputs,
+         max_length=max_length,
+         num_return_sequences=num_ideas,
+         do_sample=True,          # sampling yields more diverse ideas
+         temperature=0.8,
+         top_k=50,
+         top_p=0.95,
+         no_repeat_ngram_size=2   # avoid repetitive phrases
+     )
+
+     ideas = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
+     # Clean up GPU memory if applicable
+     del inputs
+     del outputs
+     if config.DEVICE == "cuda":
+         torch.cuda.empty_cache()
+     gc.collect()
+     return ideas
+
+
+ def generate_image_sync(prompt: str, negative_prompt: str | None, height: int, width: int, num_inference_steps: int, guidance_scale: float) -> str:
+     """Synchronous function for image generation (run in a thread pool)."""
+     pipeline = model_cache.get("image_pipeline")
+     if not pipeline:
+         raise RuntimeError("Image pipeline not loaded.")
+
+     try:
+         with torch.no_grad():  # conserve memory during inference
+             result = pipeline(
+                 prompt=prompt,
+                 negative_prompt=negative_prompt,
+                 height=height,
+                 width=width,
+                 num_inference_steps=num_inference_steps,
+                 guidance_scale=guidance_scale,
+                 # generator=torch.Generator(device=config.DEVICE).manual_seed(seed)  # optional: for reproducibility
+             )
+             image: Image.Image = result.images[0]
+
+         # Encode the image as base64
+         image_base64 = encode_image_base64(image, format="PNG")
+
+     finally:
+         # Clean up GPU memory if applicable
+         if config.DEVICE == "cuda":
+             torch.cuda.empty_cache()
+         gc.collect()
+
+     return image_base64
+
+
+ def generate_video_sync(
+     image_base64: str,
+     prompt: str | None,
+     motion_bucket_id: int,
+     noise_aug_strength: float,
+     num_frames: int,
+     fps: int,
+     num_inference_steps: int,
+     guidance_scale: float
+ ) -> tuple[str, str]:
+     """Synchronous function for video generation (run in a thread pool)."""
+     pipeline = model_cache.get("video_pipeline")
+     if not pipeline:
+         raise RuntimeError("Video pipeline not loaded.")
+
+     input_image = decode_base64_image(image_base64)
+     video_frames = None
+     video_frames_np = None  # pre-initialize so the finally block is safe if generation fails
+
+     try:
+         with torch.no_grad():
+             # CPU offload handles device placement, so no explicit .to(config.DEVICE) is needed
+             video_frames = pipeline(
+                 input_image,
+                 prompt=prompt,  # image-to-video models use the prompt loosely, mostly for style
+                 num_inference_steps=num_inference_steps,
+                 num_frames=num_frames,
+                 height=input_image.height,  # usually match the input image size
+                 width=input_image.width,
+                 guidance_scale=guidance_scale,
+                 motion_bucket_id=motion_bucket_id,
+                 noise_aug_strength=noise_aug_strength
+             ).frames[0]  # output is often nested: [[frame1, frame2, ...]]
+
+         # video_frames is usually a list of PIL images; convert to numpy arrays for encoding
+         video_frames_np = [np.array(frame) for frame in video_frames]
+
+         # Encode the video as base64
+         video_base64, actual_format = encode_video_base64(video_frames_np, fps=fps, format="MP4")  # request MP4; the helper handles fallback
+
+     finally:
+         # Clean up GPU/CPU memory. Offloading keeps VRAM usage low,
+         # but make sure general RAM is freed as well.
+         del input_image
+         del video_frames
+         del video_frames_np
+         if config.DEVICE == "cuda":
+             torch.cuda.empty_cache()  # still good practice
+         gc.collect()
+
+     return video_base64, actual_format
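For context, here is a minimal sketch of how this module might be wired into the web app it appears written for. Nothing below is part of this commit: the FastAPI lifespan hook, the `/ideas` endpoint, and its parameter defaults are assumptions; only `load_models()` and `generate_ideas_sync()` come from the file above, and the `config` module referenced throughout is assumed to define `DEVICE`, `DTYPE`, and the `*_MODEL_NAME` constants.

```python
# Hypothetical wiring (not part of this commit): consuming services/generation.py
# from a FastAPI app. Endpoint shape and defaults are assumptions for illustration.
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool

from services import generation


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load all models once at startup; load_models() re-raises on failure,
    # which prevents the app from starting with missing models.
    generation.load_models()
    yield


app = FastAPI(lifespan=lifespan)


@app.post("/ideas")
async def ideas(prompt: str, max_length: int = 128, num_ideas: int = 3):
    try:
        # The *_sync functions block the event loop, so run them in a thread pool
        results = await run_in_threadpool(
            generation.generate_ideas_sync, prompt, max_length, num_ideas
        )
    except RuntimeError as exc:  # raised above when a model is not loaded
        raise HTTPException(status_code=503, detail=str(exc))
    return {"ideas": results}
```

The same `await run_in_threadpool(...)` pattern would apply to `generate_image_sync` and `generate_video_sync`, which is presumably why all three are written as plain synchronous functions with "run in thread pool" docstrings.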