Ryukijano commited on
Commit
6262a36
·
1 Parent(s): 753e6a3

Fix HF Space build: Remove Predict git dependency and handle gracefully

Browse files
Files changed (4) hide show
  1. app.py +5 -2
  2. predict_backend.py +9 -3
  3. requirements.txt +0 -1
  4. space_backend.py +5 -1
app.py CHANGED
@@ -357,8 +357,11 @@ def run_pipeline_action(input_video_path):
357
  def run_predict_action(pipeline_payload, input_video_path, selection):
358
  if not pipeline_payload:
359
  raise gr.Error("Run BADAS + Reason before Predict.")
360
- predict_payload, merged_payload = run_predict_only(pipeline_payload, selection=selection, predict_model_name=PREDICT_MODEL_NAME)
361
- return build_outputs(merged_payload, "Predict completed.", input_video_path, predict_payload)
 
 
 
362
 
363
 
364
  def safe_warmup_message():
 
def run_predict_action(pipeline_payload, input_video_path, selection):
    """Run the Predict stage on an already-computed pipeline payload.

    Args:
        pipeline_payload: Result of the BADAS + Reason pipeline; must be truthy.
        input_video_path: Source video path, forwarded to build_outputs.
        selection: Selection forwarded to run_predict_only.

    Returns:
        The value of build_outputs(...) for the merged payload.

    Raises:
        gr.Error: if the pipeline has not been run yet, or if the Predict
            backend fails (e.g. Cosmos Predict is not installed).
    """
    if not pipeline_payload:
        raise gr.Error("Run BADAS + Reason before Predict.")
    try:
        predict_payload, merged_payload = run_predict_only(
            pipeline_payload,
            selection=selection,
            predict_model_name=PREDICT_MODEL_NAME,
        )
        return build_outputs(merged_payload, "Predict completed.", input_video_path, predict_payload)
    except gr.Error:
        # Already a user-facing Gradio error; re-raise without double-wrapping.
        raise
    except Exception as e:
        # Surface backend failures (missing Cosmos Predict, OOM, ...) in the UI,
        # preserving the original traceback via exception chaining.
        raise gr.Error(f"Predict failed: {e}") from e
365
 
366
 
367
  def safe_warmup_message():
predict_backend.py CHANGED
@@ -193,8 +193,11 @@ def build_cache_key(source_video_path, badas_context, reason_context, mode, mode
193
 
194
  @lru_cache(maxsize=2)
195
  def get_predict_inference(model_name, output_root_str, disable_guardrails=True):
196
- from cosmos_predict2.config import SetupArguments
197
- from cosmos_predict2.inference import Inference
 
 
 
198
 
199
  setup_args = SetupArguments(
200
  output_dir=Path(output_root_str),
@@ -241,7 +244,10 @@ def prepare_conditioning_input(source_video_path, badas_context, reason_context,
241
 
242
 
243
  def execute_predict_generation(output_root, model_name, sample_name, conditioning_path, prompt):
244
- from cosmos_predict2.config import InferenceArguments
 
 
 
245
 
246
  inference = get_predict_inference(model_name, str(output_root), True)
247
  inference_args = InferenceArguments(
 
193
 
194
  @lru_cache(maxsize=2)
195
  def get_predict_inference(model_name, output_root_str, disable_guardrails=True):
196
+ try:
197
+ from cosmos_predict2.config import SetupArguments
198
+ from cosmos_predict2.inference import Inference
199
+ except ImportError:
200
+ raise RuntimeError("Cosmos Predict is not installed in this environment.")
201
 
202
  setup_args = SetupArguments(
203
  output_dir=Path(output_root_str),
 
244
 
245
 
246
  def execute_predict_generation(output_root, model_name, sample_name, conditioning_path, prompt):
247
+ try:
248
+ from cosmos_predict2.config import InferenceArguments
249
+ except ImportError:
250
+ raise RuntimeError("Cosmos Predict is not installed in this environment.")
251
 
252
  inference = get_predict_inference(model_name, str(output_root), True)
253
  inference_args = InferenceArguments(
requirements.txt CHANGED
@@ -13,4 +13,3 @@ accelerate
13
  torch
14
  torchvision
15
  albumentations
16
- git+https://github.com/nvidia-cosmos/cosmos-predict2.5.git
 
13
  torch
14
  torchvision
15
  albumentations
 
space_backend.py CHANGED
@@ -90,7 +90,11 @@ def preload_runtime(preload_badas=True, preload_reason=True, preload_predict=Fal
90
  steps.append(f"Reason model ready: {reason_model_name}")
91
  if preload_predict:
92
- get_predict_inference(predict_model_name, str(PREDICT_OUTPUT_ROOT), True)
93
- steps.append(f"Predict model ready: {predict_model_name}")
 
 
 
 
94
  return "\n".join(steps)
95
 
96
 
 
90
  steps.append(f"Reason model ready: {reason_model_name}")
91
  if preload_predict:
92
+ try:
93
+ get_predict_inference(predict_model_name, str(PREDICT_OUTPUT_ROOT), True)
94
+ steps.append(f"Predict model ready: {predict_model_name}")
95
+ except Exception as e:
96
+ # Best-effort preload: Cosmos Predict may be absent on HF Spaces
97
+ steps.append(f"Predict model skipped: {e}")
98
  return "\n".join(steps)
99
 
100