File size: 5,147 Bytes
549c270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# src/models/fusion.py
from __future__ import annotations

import numpy as np
from typing import Optional, Sequence, Tuple

# Public API of this module; underscore-prefixed helpers below are internal.
__all__ = ["l2norm", "concat_fusion", "weighted_sum_fusion"]


def _as_float32(x: np.ndarray | Sequence) -> np.ndarray:
    """
    Cast input to a C-contiguous float32 ndarray, promoting 1D input to one row.

    Parameters
    ----------
    x : np.ndarray or sequence
        Array-like input. A 1D input of length D is treated as a single
        vector and returned with shape [1, D]; an empty 1D input becomes
        [1, 0]. Inputs with two or more dimensions keep their shape.

    Returns
    -------
    np.ndarray
        C-contiguous float32 array.
    """
    arr = np.asarray(x)
    if arr.ndim == 1:
        # reshape(1, -1) instead of reshape(-1, D): when D == 0 NumPy cannot
        # infer the -1 axis and raises, so an empty vector used to crash here.
        arr = arr.reshape(1, -1)
    return np.ascontiguousarray(arr, dtype=np.float32)


def l2norm(x: np.ndarray, eps: float = 1e-9) -> np.ndarray:
    """
    Scale every row of a [N, D] matrix to (approximately) unit Euclidean length.

    Parameters
    ----------
    x : np.ndarray
        Input of shape [N, D]; it is first converted with ``_as_float32``.
    eps : float
        Added to each row norm before dividing, so all-zero rows do not
        trigger a division by zero.

    Returns
    -------
    np.ndarray
        Array of shape [N, D] whose rows are divided by ``norm + eps``.
    """
    rows = _as_float32(x)
    # keepdims=True keeps the norms as a column so they broadcast across D.
    row_norms = np.linalg.norm(rows, axis=1, keepdims=True)
    return rows / (row_norms + eps)


def _gather_modalities(
    Vt: Optional[np.ndarray] = None,
    Vi: Optional[np.ndarray] = None,
    Vm: Optional[np.ndarray] = None,
) -> Tuple[list[np.ndarray], list[str]]:
    """
    Collect the provided modality matrices in the fixed order text, image, meta.

    Parameters
    ----------
    Vt, Vi, Vm : np.ndarray or None
        Text, image and meta matrices. ``None`` entries are skipped; the
        others are converted with ``_as_float32``.

    Returns
    -------
    (list of np.ndarray, list of str)
        The converted matrices and the matching modality names
        ("text", "image", "meta"), in that order.

    Raises
    ------
    ValueError
        If every argument is None, or if the provided matrices do not all
        have the same size along the first axis.
    """
    blocks: list[np.ndarray] = []
    names: list[str] = []
    for label, matrix in (("text", Vt), ("image", Vi), ("meta", Vm)):
        if matrix is None:
            continue
        blocks.append(_as_float32(matrix))
        names.append(label)

    if len(blocks) == 0:
        raise ValueError("At least one modality (Vt/Vi/Vm) must be provided.")

    # Every modality must describe the same set of items (same row count).
    row_counts = {block.shape[0] for block in blocks}
    if len(row_counts) > 1:
        raise ValueError(f"All modalities must have same N items, got Ns={row_counts}.")
    return blocks, names


def concat_fusion(
    Vt: Optional[np.ndarray] = None,
    Vi: Optional[np.ndarray] = None,
    Vm: Optional[np.ndarray] = None,
    weights: Optional[Sequence[float]] = None,
) -> np.ndarray:
    """
    Fuse modalities by scaling each block, concatenating along the feature
    axis, and L2-normalizing the result.

    Parameters
    ----------
    Vt, Vi, Vm : np.ndarray or None
        Matrices of shape [I, Dt], [I, Di], [I, Dm]. ``None`` modalities are
        left out of the concatenation.
    weights : sequence of floats or None
        Per-modality scale factors given positionally as (text, image, meta).
        A sequence shorter than three is padded with 1.0; entries past the
        third are ignored. With None every provided modality is scaled by 1.

    Returns
    -------
    np.ndarray
        L2-normalized matrix of shape [I, sum of the provided block widths].

    Raises
    ------
    ValueError
        Propagated from ``_gather_modalities`` when no modality is provided
        or the item counts disagree.
    """
    blocks, names = _gather_modalities(Vt, Vi, Vm)

    if weights is None:
        scales = [1.0 for _ in blocks]
    else:
        # Weights are positional over (text, image, meta); pad to three and
        # look up only the modalities that are actually present.
        padded = list(weights) + [1.0] * (3 - len(weights))
        by_name = dict(zip(("text", "image", "meta"), padded))
        scales = [float(by_name[name]) for name in names]

    scaled_blocks = []
    for block, scale in zip(blocks, scales):
        scaled_blocks.append(block * scale)

    return l2norm(np.concatenate(scaled_blocks, axis=1))


def _pad_to_dim(x: np.ndarray, target_dim: int) -> np.ndarray:
    """
    Right-pad the rows of a 2D matrix with zeros up to ``target_dim`` columns.

    Parameters
    ----------
    x : [N, D]
        Input matrix; converted with ``_as_float32`` first.
    target_dim : int
        Desired number of columns.

    Returns
    -------
    [N, target_dim]
        The converted input itself when D already equals ``target_dim``,
        otherwise a new float32 matrix whose first D columns hold the input
        and whose remaining columns are zero.

    Raises
    ------
    ValueError
        If D is larger than ``target_dim`` (the input is never truncated).
    """
    mat = _as_float32(x)
    n_rows, width = mat.shape
    if width > target_dim:
        raise ValueError(f"Block has dim {width} > target_dim {target_dim}.")
    if width == target_dim:
        return mat
    padded = np.zeros((n_rows, target_dim), dtype=np.float32)
    padded[:, :width] = mat
    return padded


def weighted_sum_fusion(
    Vt: Optional[np.ndarray] = None,
    Vi: Optional[np.ndarray] = None,
    Vm: Optional[np.ndarray] = None,
    weights: Optional[Sequence[float]] = None,
) -> np.ndarray:
    """
    Fuse modalities by a weighted element-wise sum, then L2-normalize.

    Notes
    -----
    - Blocks narrower than the widest provided block are zero-padded on the
      right (via ``_pad_to_dim``) so that all blocks share one width before
      they are summed.
    - The summed matrix is passed through ``l2norm`` before being returned.

    Parameters
    ----------
    Vt, Vi, Vm : np.ndarray or None
        Matrices of shape [I, Dt], [I, Di], [I, Dm]. ``None`` modalities do
        not take part in the sum.
    weights : sequence of floats or None
        Per-modality scale factors given positionally as (text, image, meta).
        A sequence shorter than three is padded with 1.0; entries past the
        third are ignored. With None every provided modality is scaled by 1.

    Returns
    -------
    np.ndarray
        L2-normalized matrix of shape [I, D_max], where D_max is the largest
        width among the provided blocks.

    Raises
    ------
    ValueError
        Propagated from ``_gather_modalities`` when no modality is provided
        or the item counts disagree.
    """
    blocks, names = _gather_modalities(Vt, Vi, Vm)

    # Common width every block is padded to.
    width = max(block.shape[1] for block in blocks)

    if weights is None:
        scales = [1.0 for _ in blocks]
    else:
        # Weights are positional over (text, image, meta); pad to three and
        # look up only the modalities that are actually present.
        padded = list(weights) + [1.0] * (3 - len(weights))
        by_name = dict(zip(("text", "image", "meta"), padded))
        scales = [float(by_name[name]) for name in names]

    n_items = blocks[0].shape[0]
    total = np.zeros((n_items, width), dtype=np.float32)
    for block, scale in zip(blocks, scales):
        total += _pad_to_dim(block, width) * scale

    return l2norm(total)

# NOTE(review): repeats the identical ``__all__`` assignment made near the top
# of this module; redundant but harmless.
__all__ = ["l2norm", "concat_fusion", "weighted_sum_fusion"]