File size: 3,634 Bytes
edf6793
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import types
import numpy as np
from PIL import Image
import importlib
import builtins
import sys
from pathlib import Path


# Put the directory two levels above this file at the front of sys.path so the
# import below can resolve modules located there.
# NOTE(review): presumably that directory is the project root containing
# app_local.py — confirm against the repository layout.
sys.path.insert(0, str(Path(__file__).parent.parent))

# Import the module under test by name (after the sys.path change above).
app = importlib.import_module("app_local")

def get_fn(mod, *names):
    """Look up an attribute on ``mod`` by trying several candidate names.

    Args:
        mod: Object (typically a module) to inspect; it must expose
            ``__name__`` because that is used in the error message.
        *names: Candidate attribute names, tried in the order given.

    Returns:
        The value of the first candidate attribute that exists on ``mod``.

    Raises:
        AttributeError: If none of the candidates exists on ``mod``.
    """
    # A private sentinel distinguishes "attribute missing" from an attribute
    # whose value happens to be None or otherwise falsy.
    absent = object()
    for candidate in names:
        value = getattr(mod, candidate, absent)
        if value is not absent:
            return value
    raise AttributeError(f"None of {names} found in {mod.__name__}")

# Resolve the entry points under test once, at module import time.
# predict_video is accepted under either of two names; get_fn raises
# AttributeError when no candidate exists, so a missing function makes this
# module fail on import rather than inside an individual test.
predict_video = get_fn(app, "predict_video", "predict_vid")
predict_image_audio = get_fn(app, "predict_image_audio")

def test_predict_image_audio_fuses_correctly(monkeypatch):
    """Verify that the third argument of ``predict_image_audio`` decides which modality wins.

    The image model is stubbed to put all probability on label index 0 and the
    audio helpers are stubbed to put all probability on label index 1. The
    function is then called with 0.9 and with 0.1, and the test expects label 0
    to outrank label 1 in the first case and the reverse in the second.
    NOTE(review): the third argument is presumably the image-side fusion
    weight — confirm against ``app_local``.

    Args:
        monkeypatch: pytest fixture used to replace attributes on ``app``.
    """
    n_labels = len(app.lables)

    # One-hot distributions so the expected winner is unambiguous.
    image_probs = np.zeros(n_labels)
    image_probs[0] = 1.0  # image votes class 0
    audio_probs = np.zeros(n_labels)
    audio_probs[1] = 1.0  # audio votes class 1

    # Swap the heavy model / IO calls for lightweight deterministic stubs.
    monkeypatch.setattr(app, "clip_image_probs", lambda img, **kw: image_probs, raising=True)

    # These audio helpers are only patched when the module actually defines
    # them: (attribute name, stub, value passed as ``raising``).
    optional_audio_stubs = (
        ("wav2vec2_zero_shot_probs", lambda wave, **kw: audio_probs, True),
        ("audio_prior_from_rms", lambda rms: audio_probs, False),
        (
            "wav2vec2_embed_energy",
            lambda wave: (np.zeros(768, dtype=np.float32), 0.5),
            True,
        ),
    )
    for attr_name, stub, must_exist in optional_audio_stubs:
        if hasattr(app, attr_name):
            monkeypatch.setattr(app, attr_name, stub, raising=must_exist)

    monkeypatch.setattr(
        app, "load_audio_16k", lambda path: np.zeros(16000, dtype=np.float32), raising=True
    )
    # Replace logging with a no-op (no file writes).
    monkeypatch.setattr(app, "log_inference", lambda **kw: None, raising=False)

    # Dummy inputs; the audio path is never read because load_audio_16k is stubbed.
    image = Image.new("RGB", (64, 64), color=128)
    audio_path = "dummy.wav"

    pred_hi, probs_hi, _ = predict_image_audio(image, audio_path, 0.9)
    pred_lo, probs_lo, _ = predict_image_audio(image, audio_path, 0.1)

    # Labels that the image stub and the audio stub voted for, respectively.
    image_label = app.lables[0]
    audio_label = app.lables[1]

    assert probs_hi[image_label] > probs_hi[audio_label]
    assert probs_lo[audio_label] > probs_lo[image_label]
    # Both fused outputs should still sum to 1 within tolerance.
    for fused in (probs_hi, probs_lo):
        assert 0.99 <= sum(fused.values()) <= 1.01
    assert all(isinstance(pred, str) for pred in (pred_hi, pred_lo))

def test_predict_video_path_runs_with_stubs(monkeypatch):
    """Smoke-test ``predict_video`` end to end with every heavy dependency stubbed.

    The frame/audio extractor, the image model, the optional audio helpers and
    the logger on ``app`` are replaced with stubs, then ``predict_video`` is
    called once and the shape of its three return values is checked.

    Args:
        monkeypatch: pytest fixture used to replace attributes on ``app``.
    """
    K = len(app.lables)
    p_img = np.zeros(K)
    p_img[-1] = 1.0  # image favors last class
    p_aud = np.zeros(K)
    p_aud[0] = 1.0  # audio favors first class

    # Stub the frame extractor (avoids ffmpeg): five identical red frames,
    # a silent 16000-sample waveform, and matching metadata.
    frames = [Image.new("RGB", (32, 32), (255, 0, 0)) for _ in range(5)]
    wave = np.zeros(16000, dtype=np.float32)
    meta = {"n_frames": len(frames), "fps_used": 1.0, "duration_s": 5.0}
    monkeypatch.setattr(
        app, "video_to_frame_audio", lambda v, **kw: (frames, wave, meta), raising=True
    )

    # Stub models; the audio helpers are patched only if the module defines them.
    monkeypatch.setattr(app, "clip_image_probs", lambda pil, **kw: p_img, raising=True)
    if hasattr(app, "wav2vec2_zero_shot_probs"):
        monkeypatch.setattr(
            app, "wav2vec2_zero_shot_probs", lambda w, **kw: p_aud, raising=True
        )
    if hasattr(app, "wav2vec2_embed_energy"):
        monkeypatch.setattr(
            app,
            "wav2vec2_embed_energy",
            lambda w: (np.zeros(768, dtype=np.float32), 0.3),
            raising=True,
        )
    # Replace logging with a no-op.
    monkeypatch.setattr(app, "log_inference", lambda **kw: None, raising=False)

    # Call
    pred, probs, lat = predict_video("dummy.mp4", 0.5)

    # Checks
    assert isinstance(pred, str)
    assert set(probs.keys()) == set(app.lables)
    # The previous condition `A and B or A` reduced to `A`, so "n_frames" was
    # never actually required. State the effective requirement explicitly:
    # only the total latency key must be present.
    assert "t_total_ms" in lat