SonicaB committed on
Commit
2abc409
·
verified ·
1 Parent(s): f4e50b7

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. fusion-app/app_local.py +33 -69
fusion-app/app_local.py CHANGED
@@ -31,11 +31,18 @@ def _img_to_jpeg_bytes(pil: Image.Image) -> bytes:
31
  pil.convert("RGB").save(buf, format="JPEG", quality=90)
32
  return buf.getvalue()
33
 
 
 
 
 
 
 
 
34
  def clip_api_probs(pil_img, prompts, token):
35
  """
36
- Zero-shot image classification via CLIP using the official client.
37
- Strategy: try pinned model → retry with provider default → fallback to local.
38
- Returns a normalized np.array of shape [len(prompts)].
39
  """
40
  client = InferenceClient(token=token)
41
 
@@ -45,33 +52,26 @@ def clip_api_probs(pil_img, prompts, token):
45
  s = arr.sum()
46
  return (arr / s) if s > 0 else np.ones(len(prompts), dtype=np.float32) / len(prompts)
47
 
48
- # 1) try your pinned checkpoint first
49
- try:
50
- img_bytes = _img_to_jpeg_bytes(pil_img) # ← convert PIL → bytes
51
- res = client.zero_shot_image_classification(
52
- image=img_bytes, # ← pass bytes (not PIL)
53
- candidate_labels=prompts,
54
- hypothesis_template="{}",
55
- model=CLIP_MODEL,
56
- )
57
- return _to_arr(res)
58
- except (StopIteration, HfHubHTTPError, ValueError) as e:
59
- print(f"[WARN] CLIP provider/model unavailable ({e}); retrying with provider default.", flush=True)
60
-
61
- # 2) provider default for the task
62
- try:
63
- img_bytes = _img_to_jpeg_bytes(pil_img) # ← convert again for clarity
64
- res = client.zero_shot_image_classification(
65
- image=img_bytes, # ← pass bytes
66
- candidate_labels=prompts,
67
- hypothesis_template="{}",
68
- model=None,
69
- )
70
- return _to_arr(res)
71
- except (StopIteration, HfHubHTTPError, ValueError) as e:
72
- print(f"[WARN] CLIP default route failed ({e}); falling back to local.", flush=True)
73
- from fusion import clip_image_probs as local_clip
74
- return local_clip(pil_img)
75
 
76
  def _wave_float32_to_wav_bytes(wave_16k: np.ndarray, sr=16000) -> bytes:
77
  samples = (np.clip(wave_16k, -1, 1) * 32767.0).astype(np.int16)
@@ -81,45 +81,9 @@ def _wave_float32_to_wav_bytes(wave_16k: np.ndarray, sr=16000) -> bytes:
81
  return out.getvalue()
82
 
83
  def w2v2_api_embed(wave_16k, token):
84
- """
85
- Feature extraction via the official client.
86
- Strategy: try pinned model → retry with provider default → fallback to local.
87
- Returns a mean-pooled, L2-normalized embedding (np.float32).
88
- """
89
- client = InferenceClient(token=token)
90
-
91
- def _mean_l2(feats):
92
- arr = np.asarray(feats, dtype=np.float32) # [T, D] or [1, T, D]
93
- if arr.ndim == 3:
94
- arr = arr[0]
95
- vec = arr.mean(axis=0)
96
- n = np.linalg.norm(vec) + 1e-8
97
- return (vec / n).astype(np.float32)
98
-
99
- def _feats_with_backoff(model_id):
100
- try:
101
- return client.feature_extraction(audio=wave_16k, model=model_id)
102
- except (HfHubHTTPError, StopIteration) as e:
103
- raise e
104
- except Exception as e:
105
- wav_bytes = _wave_float32_to_wav_bytes(wave_16k)
106
- return client.feature_extraction(audio=wav_bytes, model=model_id)
107
-
108
- try:
109
- feats = _feats_with_backoff(W2V2_MODEL)
110
- return _mean_l2(feats)
111
- except (StopIteration, HfHubHTTPError) as e:
112
- print(f"[WARN] W2V2 provider/model unavailable ({e}); retrying with provider default.", flush=True)
113
- pass
114
-
115
- try:
116
- feats = _feats_with_backoff(None)
117
- return _mean_l2(feats)
118
- except (StopIteration, HfHubHTTPError) as e:
119
- print(f"[WARN] W2V2 default route failed ({e}); falling back to local.", flush=True)
120
- from fusion import wav2vec2_embed_energy
121
- emb, _ = wav2vec2_embed_energy(wave_16k)
122
- return emb
123
 
124
  _PROTO_EMBS_API = None
125
 
 
31
  pil.convert("RGB").save(buf, format="JPEG", quality=90)
32
  return buf.getvalue()
33
 
34
+ CLIP_CANDIDATES = [
35
+ CLIP_MODEL,
36
+ "openai/clip-vit-large-patch14-336",
37
+ "laion/CLIP-ViT-B-32-laion2B-s34B-b79K",
38
+ None,
39
+ ]
40
+
41
  def clip_api_probs(pil_img, prompts, token):
42
  """
43
+ Zero-shot image classification via InferenceClient.
44
+ Try pinned → candidates → provider default → fallback LOCAL.
45
+ Returns np.array[K] normalized.
46
  """
47
  client = InferenceClient(token=token)
48
 
 
52
  s = arr.sum()
53
  return (arr / s) if s > 0 else np.ones(len(prompts), dtype=np.float32) / len(prompts)
54
 
55
+ img_bytes = _img_to_jpeg_bytes(pil_img) # PIL -> bytes
56
+
57
+ last_err = None
58
+ for mid in CLIP_CANDIDATES:
59
+ try:
60
+ res = client.zero_shot_image_classification(
61
+ image=img_bytes, # bytes (compatible across hub versions)
62
+ candidate_labels=prompts,
63
+ hypothesis_template="{}",
64
+ model=mid,
65
+ )
66
+ return _to_arr(res)
67
+ except (HfHubHTTPError, StopIteration, ValueError) as e:
68
+ print(f"[WARN] CLIP provider/model {mid or 'DEFAULT'} failed ({e}); trying next.", flush=True)
69
+ last_err = e
70
+
71
+ # Final fallback: LOCAL CLIP to keep UX working
72
+ print(f"[WARN] CLIP all provider routes failed ({last_err}); falling back to LOCAL.", flush=True)
73
+ from fusion import clip_image_probs as local_clip
74
+ return local_clip(pil_img)
 
 
 
 
 
 
 
75
 
76
  def _wave_float32_to_wav_bytes(wave_16k: np.ndarray, sr=16000) -> bytes:
77
  samples = (np.clip(wave_16k, -1, 1) * 32767.0).astype(np.int16)
 
81
  return out.getvalue()
82
 
83
  def w2v2_api_embed(wave_16k, token):
84
+ from fusion import wav2vec2_embed_energy
85
+ emb, _ = wav2vec2_embed_energy(wave_16k)
86
+ return emb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  _PROTO_EMBS_API = None
89