Fred808 committed on
Commit
9cf6256
·
verified ·
1 Parent(s): acc5466

Update tensor_server.py

Browse files
Files changed (1) hide show
  1. tensor_server.py +42 -12
tensor_server.py CHANGED
@@ -144,21 +144,51 @@ def load_chunk(chunk: ModelChunk) -> torch.nn.Module:
144
  raise ValueError(f"Chunk file not found: {chunk_file}")
145
 
146
  # For raw binary chunks, we'll create a simple buffer module
147
class ChunkBuffer(torch.nn.Module):
    """Placeholder module wrapping one raw binary model chunk.

    Records where the chunk lives on disk and the byte range it covers,
    and exposes ``vocab_offset`` as a plain attribute so the aggregator
    can read it. The forward pass is an identity — real chunk processing
    is not implemented yet; this only exercises the distribution system.
    """

    def __init__(self, chunk_path: str, config: Dict):
        super().__init__()
        # Location of the raw chunk on disk plus its byte window.
        self.chunk_path = chunk_path
        self.config = config
        self.start_offset = config.get('start_offset', 0)
        self.size = config.get('size_bytes', 0)
        # expose vocab_offset on the module for aggregator use
        self.vocab_offset = int(config.get('vocab_offset', 0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Identity pass-through; the input tensor is returned untouched."""
        # A real implementation would consume the chunk data here.
        return x
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
  # Create and return the chunk buffer
164
  chunk_model = ChunkBuffer(chunk_file, chunk_config)
 
144
  raise ValueError(f"Chunk file not found: {chunk_file}")
145
 
146
  # For raw binary chunks, we'll create a simple buffer module
147
class ChunkBuffer(nn.Module):
    """
    A single Florence-2 caption chunk that receives pre-encoded image embeddings
    and produces partial vocabulary logits.

    The module keeps the pre-refactor bookkeeping attributes (``config``,
    ``start_offset``, ``size``, ``vocab_offset``) so existing aggregator code
    that reads them off the module keeps working.
    """

    def __init__(self, chunk_path: str, config: dict):
        """
        Args:
            chunk_path: path to this chunk's file; may point to pretrained
                weights (not loaded here — TODO confirm loading happens upstream).
            config: chunk configuration; recognized keys are ``input_dim``,
                ``output_dim``, ``dropout``, ``start_offset``, ``size_bytes``
                and ``vocab_offset``, all optional.
        """
        super().__init__()

        # Get dimensions from config
        input_dim = config.get("input_dim", 1024)  # Florence-2 embedding dim
        output_dim = config.get("output_dim", 1000)  # size of vocab shard
        dropout = config.get("dropout", 0.1)

        # Optional: chunk_path can point to pretrained weights
        self.chunk_path = chunk_path

        # Backward-compatible bookkeeping (present before the refactor;
        # vocab_offset in particular is read by the aggregator).
        self.config = config
        self.start_offset = config.get("start_offset", 0)
        self.size = config.get("size_bytes", 0)
        self.vocab_offset = int(config.get("vocab_offset", 0))

        # Main projection layer: embedding → partial vocab logits
        self.linear = nn.Linear(input_dim, output_dim)

        # Optional normalization + dropout (stabilizes training or inference variance)
        self.norm = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

        # Initialize weights (small variance, stable logits)
        nn.init.xavier_uniform_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Florence-2 image embedding tensor, shape [batch, input_dim]
        Returns:
            logits for this vocab shard, shape [batch, output_dim]
        """
        # Normalize + dropout
        x = self.norm(x)
        x = self.dropout(x)

        # Linear projection to vocab slice; softmax (if wanted) is the
        # aggregator's job, so raw logits are returned.
        logits = self.linear(x)
        return logits
192
 
193
  # Create and return the chunk buffer
194
  chunk_model = ChunkBuffer(chunk_file, chunk_config)