SonicaB commited on
Commit
f4e50b7
·
verified ·
1 Parent(s): 7c217cb

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. fusion-app/app_local.py +21 -11
fusion-app/app_local.py CHANGED
@@ -9,7 +9,7 @@ from pydub import AudioSegment
9
  from utils_media import video_to_frame_audio, load_audio_16k, log_inference
10
  from fusion import clip_image_probs, wav2vec2_embed_energy, wav2vec2_zero_shot_probs, audio_prior_from_rms, fuse_probs, top1_label_from_probs
11
  from fusion import _ensure_audio_prototypes, _proto_embs
12
-
13
 
14
  HERE = Path(__file__).parent
15
  lables_PATH = HERE / "labels.json"
@@ -45,28 +45,30 @@ def clip_api_probs(pil_img, prompts, token):
45
  s = arr.sum()
46
  return (arr / s) if s > 0 else np.ones(len(prompts), dtype=np.float32) / len(prompts)
47
 
 
48
  try:
 
49
  res = client.zero_shot_image_classification(
50
- image=pil_img,
51
  candidate_labels=prompts,
52
  hypothesis_template="{}",
53
  model=CLIP_MODEL,
54
  )
55
  return _to_arr(res)
56
- except (StopIteration, HfHubHTTPError) as e:
57
-
58
  print(f"[WARN] CLIP provider/model unavailable ({e}); retrying with provider default.", flush=True)
59
- pass
60
 
 
61
  try:
 
62
  res = client.zero_shot_image_classification(
63
- image=pil_img,
64
  candidate_labels=prompts,
65
  hypothesis_template="{}",
66
  model=None,
67
  )
68
  return _to_arr(res)
69
- except (StopIteration, HfHubHTTPError) as e:
70
  print(f"[WARN] CLIP default route failed ({e}); falling back to local.", flush=True)
71
  from fusion import clip_image_probs as local_clip
72
  return local_clip(pil_img)
@@ -94,15 +96,24 @@ def w2v2_api_embed(wave_16k, token):
94
  n = np.linalg.norm(vec) + 1e-8
95
  return (vec / n).astype(np.float32)
96
 
 
 
 
 
 
 
 
 
 
97
  try:
98
- feats = client.feature_extraction(audio=wave_16k, model=W2V2_MODEL)
99
  return _mean_l2(feats)
100
  except (StopIteration, HfHubHTTPError) as e:
101
  print(f"[WARN] W2V2 provider/model unavailable ({e}); retrying with provider default.", flush=True)
102
  pass
103
 
104
  try:
105
- feats = client.feature_extraction(audio=wave_16k, model=None)
106
  return _mean_l2(feats)
107
  except (StopIteration, HfHubHTTPError) as e:
108
  print(f"[WARN] W2V2 default route failed ({e}); falling back to local.", flush=True)
@@ -337,8 +348,7 @@ def predict_video(video, alpha=0.7):
337
  # ============= Gradio Interface =============
338
  # Only create demo if not being imported for testing
339
  # Check for pytest in sys.modules to detect test environment
340
- import sys
341
- import os
342
  _is_testing = 'pytest' in sys.modules or os.getenv('PYTEST_CURRENT_TEST') is not None
343
 
344
  # Always create demo for HF Spaces, but skip during pytest
 
9
  from utils_media import video_to_frame_audio, load_audio_16k, log_inference
10
  from fusion import clip_image_probs, wav2vec2_embed_energy, wav2vec2_zero_shot_probs, audio_prior_from_rms, fuse_probs, top1_label_from_probs
11
  from fusion import _ensure_audio_prototypes, _proto_embs
12
+ import sys
13
 
14
  HERE = Path(__file__).parent
15
  lables_PATH = HERE / "labels.json"
 
45
  s = arr.sum()
46
  return (arr / s) if s > 0 else np.ones(len(prompts), dtype=np.float32) / len(prompts)
47
 
48
+ # 1) try your pinned checkpoint first
49
  try:
50
+ img_bytes = _img_to_jpeg_bytes(pil_img) # ← convert PIL → bytes
51
  res = client.zero_shot_image_classification(
52
+ image=img_bytes, # ← pass bytes (not PIL)
53
  candidate_labels=prompts,
54
  hypothesis_template="{}",
55
  model=CLIP_MODEL,
56
  )
57
  return _to_arr(res)
58
+ except (StopIteration, HfHubHTTPError, ValueError) as e:
 
59
  print(f"[WARN] CLIP provider/model unavailable ({e}); retrying with provider default.", flush=True)
 
60
 
61
+ # 2) provider default for the task
62
  try:
63
+ img_bytes = _img_to_jpeg_bytes(pil_img) # ← convert again for clarity
64
  res = client.zero_shot_image_classification(
65
+ image=img_bytes, # ← pass bytes
66
  candidate_labels=prompts,
67
  hypothesis_template="{}",
68
  model=None,
69
  )
70
  return _to_arr(res)
71
+ except (StopIteration, HfHubHTTPError, ValueError) as e:
72
  print(f"[WARN] CLIP default route failed ({e}); falling back to local.", flush=True)
73
  from fusion import clip_image_probs as local_clip
74
  return local_clip(pil_img)
 
96
  n = np.linalg.norm(vec) + 1e-8
97
  return (vec / n).astype(np.float32)
98
 
99
+ def _feats_with_backoff(model_id):
100
+ try:
101
+ return client.feature_extraction(audio=wave_16k, model=model_id)
102
+ except (HfHubHTTPError, StopIteration) as e:
103
+ raise e
104
+ except Exception as e:
105
+ wav_bytes = _wave_float32_to_wav_bytes(wave_16k)
106
+ return client.feature_extraction(audio=wav_bytes, model=model_id)
107
+
108
  try:
109
+ feats = _feats_with_backoff(W2V2_MODEL)
110
  return _mean_l2(feats)
111
  except (StopIteration, HfHubHTTPError) as e:
112
  print(f"[WARN] W2V2 provider/model unavailable ({e}); retrying with provider default.", flush=True)
113
  pass
114
 
115
  try:
116
+ feats = _feats_with_backoff(None)
117
  return _mean_l2(feats)
118
  except (StopIteration, HfHubHTTPError) as e:
119
  print(f"[WARN] W2V2 default route failed ({e}); falling back to local.", flush=True)
 
348
  # ============= Gradio Interface =============
349
  # Only create demo if not being imported for testing
350
  # Check for pytest in sys.modules to detect test environment
351
+
 
352
  _is_testing = 'pytest' in sys.modules or os.getenv('PYTEST_CURRENT_TEST') is not None
353
 
354
  # Always create demo for HF Spaces, but skip during pytest