SonicaB's picture
Upload folder using huggingface_hub
edf6793 verified
import types
import numpy as np
from PIL import Image
import importlib
import builtins
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
app = importlib.import_module("app_local")
# Utility to get function by trying multiple possible names
def get_fn(mod, *names):
for n in names:
if hasattr(mod, n):
return getattr(mod, n)
raise AttributeError(f"None of {names} found in {mod.__name__}")
predict_video = get_fn(app, "predict_video", "predict_vid")
predict_image_audio = get_fn(app, "predict_image_audio")
def test_predict_image_audio_fuses_correctly(monkeypatch):
K = len(app.lables)
# deterministic distributions
p_img = np.zeros(K); p_img[0] = 1.0 # image votes class 0
p_aud = np.zeros(K); p_aud[1] = 1.0 # audio votes class 1
# Monkeypatch heavy parts to lightweight stubs
monkeypatch.setattr(app, "clip_image_probs", lambda img, **kw: p_img, raising=True)
if hasattr(app, "wav2vec2_zero_shot_probs"):
monkeypatch.setattr(app, "wav2vec2_zero_shot_probs", lambda wave, **kw: p_aud, raising=True)
if hasattr(app, "audio_prior_from_rms"):
monkeypatch.setattr(app, "audio_prior_from_rms", lambda rms: p_aud, raising=False)
if hasattr(app, "wav2vec2_embed_energy"):
monkeypatch.setattr(app, "wav2vec2_embed_energy", lambda wave: (np.zeros(768, dtype=np.float32), 0.5), raising=True)
monkeypatch.setattr(app, "load_audio_16k", lambda path: np.zeros(16000, dtype=np.float32), raising=True)
monkeypatch.setattr(app, "log_inference", lambda **kw: None, raising=False) # no file writes
# dummy inputs
dummy_img = Image.new("RGB", (64, 64), color=128)
dummy_audio_path = "dummy.wav"
pred_hi, probs_hi, _ = predict_image_audio(dummy_img, dummy_audio_path, 0.9)
pred_lo, probs_lo, _ = predict_image_audio(dummy_img, dummy_audio_path, 0.1)
# Indexes for label 0/1
idx0 = 0
idx1 = 1
assert probs_hi[app.lables[idx0]] > probs_hi[app.lables[idx1]]
assert probs_lo[app.lables[idx1]] > probs_lo[app.lables[idx0]]
assert 0.99 <= sum(probs_hi.values()) <= 1.01
assert 0.99 <= sum(probs_lo.values()) <= 1.01
assert isinstance(pred_hi, str) and isinstance(pred_lo, str)
def test_predict_video_path_runs_with_stubs(monkeypatch):
K = len(app.lables)
p_img = np.zeros(K); p_img[-1] = 1.0 # image favors last class
p_aud = np.zeros(K); p_aud[0] = 1.0 # audio favors first class
# Stub frame extractor to avoid ffmpeg
frames = [Image.new("RGB", (32, 32), c) for c in [(255,0,0)]*5]
wave = np.zeros(16000, dtype=np.float32)
meta = {"n_frames":5, "fps_used":1.0, "duration_s":5.0}
# Mock the video_to_frame_audio function that's imported from utils_media
monkeypatch.setattr(app, "video_to_frame_audio", lambda v, **kw: (frames, wave, meta), raising=True)
# Stub models
monkeypatch.setattr(app, "clip_image_probs", lambda pil, **kw: p_img, raising=True)
if hasattr(app, "wav2vec2_zero_shot_probs"):
monkeypatch.setattr(app, "wav2vec2_zero_shot_probs", lambda w, **kw: p_aud, raising=True)
if hasattr(app, "wav2vec2_embed_energy"):
monkeypatch.setattr(app, "wav2vec2_embed_energy", lambda w: (np.zeros(768, dtype=np.float32), 0.3), raising=True)
monkeypatch.setattr(app, "log_inference", lambda **kw: None, raising=False)
# Call
pred, probs, lat = predict_video("dummy.mp4", 0.5)
# Checks
assert isinstance(pred, str)
assert set(probs.keys()) == set(app.lables)
assert "t_total_ms" in lat and "n_frames" in lat or "t_total_ms" in lat