File size: 4,472 Bytes
3f40093
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d53b116
17e2272
 
 
 
d53b116
17e2272
 
 
d53b116
 
 
17e2272
 
 
 
 
d53b116
3f40093
c0e9503
3f40093
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import math

from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from .layers import mlp

SpatialRefs = List[Union[Tuple[float, float], Tuple[float, float, float, float]]]


def fourier_features(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """
    Map inputs through a Fourier feature transform.

    Multiplying by the frequency matrix ``w`` and passing the result through
    sin/cos lifts the input into a higher-dimensional sinusoidal basis. This
    counteracts spectral bias -- the tendency of neural networks to learn
    low-frequency functions more readily than high-frequency ones -- and so
    helps the model capture fine detail and high-frequency patterns.

    Args:
        x: Input tensor to transform.
        w: Frequency matrix for the Fourier feature mapping.

    Returns:
        Cosine features concatenated with sine features along the last
        dimension (output last dim is twice ``w``'s output dim).
    """
    projected = (2 * math.pi) * (x @ w)
    return torch.cat((torch.cos(projected), torch.sin(projected)), dim=-1)


def encode_coordinate(coord: torch.Tensor, w: nn.Module) -> torch.Tensor:
    """
    Encode a single float coordinate value (x or y) into hidden states for
    input to the text model.

    Args:
        coord: Tensor holding a single float coordinate value.
        w: Module bundle exposing ``coord_features`` (the Fourier frequency
            matrix) and ``coord_encoder`` (the projection into hidden states).

    Returns:
        Encoded hidden-state tensor for input to the text model.
    """
    features = fourier_features(coord, w.coord_features)
    return w.coord_encoder(features)


def decode_coordinate(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
    """
    Decode the text model's last hidden state into a coordinate logit.

    Args:
        hidden_state: The final hidden state tensor from the text model.
        w: Module bundle exposing ``coord_decoder``, the MLP weights used
            for decoding.

    Returns:
        A single logit representing the predicted coordinate value (x or y).
    """
    decoder_weights = w.coord_decoder
    return mlp(hidden_state, decoder_weights)


def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
    """
    Encode width/height values into hidden states for input to the text
    model.

    Args:
        size: Tensor with two floats, width and height.
        w: Module bundle exposing ``size_features`` (the Fourier frequency
            matrix) and ``size_encoder`` (the projection into hidden states).

    Returns:
        Encoded hidden-state tensor for input to the text model.
    """
    features = fourier_features(size, w.size_features)
    return w.size_encoder(features)


def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
    """
    Decode the text model's last hidden state into width/height bin logits.

    The size decoder MLP emits ``2 * C`` logits: ``C`` bins for width
    followed by ``C`` bins for height (log-scale bins per the decoder's
    design; C is 1024 in the original configuration). The flat output is
    split into shape ``(..., 2, C)`` so batched callers can index width and
    height directly.

    Args:
        hidden_state: The final hidden state tensor from the text model.
        w: Module bundle exposing ``size_decoder``, the MLP weights used
            for decoding.

    Returns:
        Logits of shape ``(..., 2, C)``.

    Raises:
        RuntimeError: If the decoder's output dimension is not even.
    """
    # Two-layer MLP projecting to a flat (..., 2*C) logit vector.
    logits = mlp(hidden_state, w.size_decoder)

    last = logits.shape[-1]
    if last % 2:
        raise RuntimeError(f"size_out_dim must be even, got {last}")

    # Split the last dim into (2, C), keeping leading batch/seq dims intact.
    half = last // 2
    return logits.view(*logits.shape[:-1], 2, half)






def encode_spatial_refs(
    spatial_refs: SpatialRefs, w: nn.Module
) -> Dict[str, Optional[torch.Tensor]]:
    """
    Encode a list of spatial references (points or regions) into hidden
    states for input to the text model.

    Boxes are converted into a center point plus a (width, height) size;
    points contribute only their coordinates. Every x/y value is encoded as
    its own row via the coordinate encoder.

    Args:
        spatial_refs: List of spatial references (points or boxes)
            - Points are represented as normalized (x, y) tuples
            - Boxes are represented as normalized
              (x_min, y_min, x_max, y_max) tuples
        w: Module bundle with the coord/size Fourier feature matrices and
            encoders (``coord_features``, ``size_features``, etc.).

    Returns:
        Dict with:
            "coords": encoded coordinate hidden states, one row per x or y
                value, in input order.
            "sizes": encoded (width, height) hidden states for boxes, or
                ``None`` when the input contains no boxes.
    """
    # Fix: the previous return annotation claimed `torch.Tensor`, but the
    # function has always returned this dict (as its docstring stated).
    coords: List[float] = []
    sizes: List[List[float]] = []
    for ref in spatial_refs:
        if len(ref) == 2:
            # Point: contribute x and y directly.
            coords.append(ref[0])
            coords.append(ref[1])
        else:
            # Box: contribute the center point, and record width/height
            # separately for the size encoder.
            x_min, y_min, x_max, y_max = ref
            coords.append((x_min + x_max) / 2)
            coords.append((y_min + y_max) / 2)
            sizes.append([x_max - x_min, y_max - y_min])

    coord_tensor = torch.tensor(
        coords, device=w.coord_features.device, dtype=w.coord_features.dtype
    ).view(-1, 1)
    encoded_coords = encode_coordinate(coord_tensor, w)

    encoded_sizes: Optional[torch.Tensor] = None
    if sizes:
        size_tensor = torch.tensor(
            sizes, device=w.size_features.device, dtype=w.size_features.dtype
        )
        encoded_sizes = encode_size(size_tensor, w)

    return {"coords": encoded_coords, "sizes": encoded_sizes}