Spaces:
Sleeping
Sleeping
Revert from_tf=True for Flan-T5 loading
Browse files- services/generation.py +6 -6
services/generation.py
CHANGED
|
@@ -37,12 +37,12 @@ def load_models():
|
|
| 37 |
# Load tokenizer associated with the text model
|
| 38 |
model_cache["text_tokenizer"] = AutoTokenizer.from_pretrained(config.TEXT_MODEL_NAME)
|
| 39 |
# Load the sequence-to-sequence language model
|
| 40 |
-
#
|
| 41 |
model_cache["text_model"] = AutoModelForSeq2SeqLM.from_pretrained(
|
| 42 |
-
config.TEXT_MODEL_NAME,
|
| 43 |
-
from_tf=True
|
| 44 |
).to(config.DEVICE) # Move model to the configured device (CPU or CUDA)
|
| 45 |
-
logger.info(f"Text model '{config.TEXT_MODEL_NAME}' loaded successfully (from TF weights) onto {config.DEVICE}.")
|
| 46 |
|
| 47 |
# --- 2. Image Generation Model (Base Pipeline) ---
|
| 48 |
logger.info(f"Loading base image generation model: {config.IMAGE_MODEL_NAME}")
|
|
@@ -377,8 +377,8 @@ def generate_video_sync(
|
|
| 377 |
finally:
|
| 378 |
# --- Resource Cleanup ---
|
| 379 |
del input_image # Delete decoded input image
|
| 380 |
-
del video_frames_pil # Delete list of PIL frames
|
| 381 |
-
del video_frames_np # Delete list of numpy frames
|
| 382 |
# Clear CUDA cache if applicable
|
| 383 |
if config.DEVICE == "cuda":
|
| 384 |
torch.cuda.empty_cache()
|
|
|
|
| 37 |
# Load tokenizer associated with the text model
|
| 38 |
model_cache["text_tokenizer"] = AutoTokenizer.from_pretrained(config.TEXT_MODEL_NAME)
|
| 39 |
# Load the sequence-to-sequence language model
|
| 40 |
+
# Assuming PyTorch weights (.bin or .safetensors) are available for the model.
|
| 41 |
model_cache["text_model"] = AutoModelForSeq2SeqLM.from_pretrained(
|
| 42 |
+
config.TEXT_MODEL_NAME
|
| 43 |
+
# REMOVED: from_tf=True - Attempt to load PyTorch weights directly.
|
| 44 |
).to(config.DEVICE) # Move model to the configured device (CPU or CUDA)
|
| 45 |
+
logger.info(f"Text model '{config.TEXT_MODEL_NAME}' loaded successfully (using PyTorch weights) onto {config.DEVICE}.")
|
| 46 |
|
| 47 |
# --- 2. Image Generation Model (Base Pipeline) ---
|
| 48 |
logger.info(f"Loading base image generation model: {config.IMAGE_MODEL_NAME}")
|
|
|
|
| 377 |
finally:
|
| 378 |
# --- Resource Cleanup ---
|
| 379 |
del input_image # Delete decoded input image
|
| 380 |
+
if 'video_frames_pil' in locals(): del video_frames_pil # Delete list of PIL frames if it exists
|
| 381 |
+
if 'video_frames_np' in locals(): del video_frames_np # Delete list of numpy frames if it exists
|
| 382 |
# Clear CUDA cache if applicable
|
| 383 |
if config.DEVICE == "cuda":
|
| 384 |
torch.cuda.empty_cache()
|