HV-Khurdula committed on
Commit
7eac0da
·
verified ·
1 Parent(s): 23f9bc4

Update region.py

Browse files

fix: restore the (2, C) return shape of decode_size.

Files changed (1) hide show
  1. region.py +3 -11
region.py CHANGED
@@ -71,19 +71,11 @@ def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
71
  return w.size_encoder(fourier_features(size, w.size_features))
72
 
73
 
74
- # region.py
75
  def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
76
- """
77
- Returns logits for width & height bins without collapsing batch/seq dims.
 
78
 
79
- Input (hidden_state): (..., C)
80
- Output: (..., 2, bins) # keeps all leading dims intact
81
- """
82
- x = mlp(hidden_state, w.size_decoder) # (..., size_out_dim)
83
- last = x.shape[-1]
84
- if last % 2 != 0:
85
- raise RuntimeError(f"size_out_dim must be even, got {last}")
86
- return x.view(*x.shape[:-1], 2, last // 2) # (..., 2, bins)
87
 
88
 
89
 
 
71
  return w.size_encoder(fourier_features(size, w.size_features))
72
 
73
 
 
74
  def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
75
+ # Original API expected by moondream.py: shape (2, C) when called on the last hidden state
76
+ x = mlp(hidden_state, w.size_decoder) # (..., 2*C)
77
+ return x.view(2, -1)
78
 
 
 
 
 
 
 
 
 
79
 
80
 
81