rajux75 commited on
Commit
12eb42e
·
verified ·
1 Parent(s): bf7d351

Revert from_tf=True for Flan-T5 loading

Browse files
Files changed (1) hide show
  1. services/generation.py +6 -6
services/generation.py CHANGED
@@ -37,12 +37,12 @@ def load_models():
37
  # Load tokenizer associated with the text model
38
  model_cache["text_tokenizer"] = AutoTokenizer.from_pretrained(config.TEXT_MODEL_NAME)
39
  # Load the sequence-to-sequence language model
40
- # IMPORTANT: Add from_tf=True if the primary weights are TensorFlow format (like google/flan-t5-base)
41
  model_cache["text_model"] = AutoModelForSeq2SeqLM.from_pretrained(
42
- config.TEXT_MODEL_NAME,
43
- from_tf=True # Required for google/flan-t5-base which has tf_model.h5
44
  ).to(config.DEVICE) # Move model to the configured device (CPU or CUDA)
45
- logger.info(f"Text model '{config.TEXT_MODEL_NAME}' loaded successfully (from TF weights if applicable) onto {config.DEVICE}.")
46
 
47
  # --- 2. Image Generation Model (Base Pipeline) ---
48
  logger.info(f"Loading base image generation model: {config.IMAGE_MODEL_NAME}")
@@ -377,8 +377,8 @@ def generate_video_sync(
377
  finally:
378
  # --- Resource Cleanup ---
379
  del input_image # Delete decoded input image
380
- del video_frames_pil # Delete list of PIL frames
381
- del video_frames_np # Delete list of numpy frames
382
  # Clear CUDA cache if applicable
383
  if config.DEVICE == "cuda":
384
  torch.cuda.empty_cache()
 
37
  # Load tokenizer associated with the text model
38
  model_cache["text_tokenizer"] = AutoTokenizer.from_pretrained(config.TEXT_MODEL_NAME)
39
  # Load the sequence-to-sequence language model
40
+ # Assuming PyTorch weights (.bin or .safetensors) are available for the model.
41
  model_cache["text_model"] = AutoModelForSeq2SeqLM.from_pretrained(
42
+ config.TEXT_MODEL_NAME
43
+ # REMOVED: from_tf=True - Attempt to load PyTorch weights directly.
44
  ).to(config.DEVICE) # Move model to the configured device (CPU or CUDA)
45
+ logger.info(f"Text model '{config.TEXT_MODEL_NAME}' loaded successfully (using PyTorch weights) onto {config.DEVICE}.")
46
 
47
  # --- 2. Image Generation Model (Base Pipeline) ---
48
  logger.info(f"Loading base image generation model: {config.IMAGE_MODEL_NAME}")
 
377
  finally:
378
  # --- Resource Cleanup ---
379
  del input_image # Delete decoded input image
380
+ if 'video_frames_pil' in locals(): del video_frames_pil # Delete list of PIL frames if it exists
381
+ if 'video_frames_np' in locals(): del video_frames_np # Delete list of numpy frames if it exists
382
  # Clear CUDA cache if applicable
383
  if config.DEVICE == "cuda":
384
  torch.cuda.empty_cache()