Fred808 commited on
Commit
554d262
·
verified ·
1 Parent(s): 94a6cd4

Update tensor_server.py

Browse files
Files changed (1) hide show
  1. tensor_server.py +53 -22
tensor_server.py CHANGED
@@ -144,32 +144,63 @@ def load_chunk(chunk: ModelChunk) -> torch.nn.Module:
144
  raise ValueError(f"Chunk file not found: {chunk_file}")
145
 
146
  # For raw binary chunks, we'll create a simple buffer module
147
- class ChunkBuffer(torch.nn.Module):
148
- def __init__(self, chunk_path: str, config: Dict):
 
 
 
 
 
149
  super().__init__()
 
 
 
 
 
 
 
150
  self.chunk_path = chunk_path
151
- self.config = config
152
- self.start_offset = config.get('start_offset', 0)
153
- self.size = config.get('size_bytes', 0)
154
- # expose vocab_offset on the module for aggregator use
155
- self.vocab_offset = int(config.get('vocab_offset', 0))
156
-
 
 
 
 
 
 
157
  def forward(self, x: torch.Tensor) -> torch.Tensor:
158
- # In a real implementation, this would process the input
159
- # using the chunk data. For now, we'll just return the input
160
- # as this is just for testing the distribution system
161
- return x
162
-
163
- # Create and return the chunk buffer
164
- chunk_model = ChunkBuffer(chunk_file, chunk_config)
165
- # Ensure the chunk_model.config is the up-to-date config (including any assigned offsets)
166
- chunk_model.config = chunk_config
167
- print(f"[INFO] Loaded chunk {chunk.chunk_id} ({chunk_config.get('size_bytes', 0)} bytes) from {chunk.files[0]}")
168
-
169
- return chunk_model
170
 
171
- except Exception as e:
172
- raise Exception(f"Failed to load chunk: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
  async def process_tensor(chunk_id: int, inputs: torch.Tensor) -> torch.Tensor:
175
  """Process input tensor through the specified chunk"""
 
144
  raise ValueError(f"Chunk file not found: {chunk_file}")
145
 
146
  # For raw binary chunks, we'll create a simple buffer module
147
+ class ChunkBuffer(nn.Module):
148
+ """
149
+ A single Florence-2 caption chunk that receives pre-encoded image embeddings
150
+ and produces partial vocabulary logits.
151
+ """
152
+
153
+ def __init__(self, chunk_path: str, config: dict):
154
  super().__init__()
155
+
156
+ # Get dimensions from config
157
+ input_dim = config.get("input_dim", 1024) # Florence-2 embedding dim
158
+ output_dim = config.get("output_dim", 1000) # size of vocab shard
159
+ dropout = config.get("dropout", 0.1)
160
+
161
+ # Optional: chunk_path can point to pretrained weights
162
  self.chunk_path = chunk_path
163
+
164
+ # Main projection layer: embedding → partial vocab logits
165
+ self.linear = nn.Linear(input_dim, output_dim)
166
+
167
+ # Optional normalization + dropout (stabilizes training or inference variance)
168
+ self.norm = nn.LayerNorm(input_dim)
169
+ self.dropout = nn.Dropout(dropout)
170
+
171
+ # Initialize weights (small variance, stable logits)
172
+ nn.init.xavier_uniform_(self.linear.weight)
173
+ nn.init.zeros_(self.linear.bias)
174
+
175
  def forward(self, x: torch.Tensor) -> torch.Tensor:
176
+ """
177
+ Args:
178
+ x: Florence-2 image embedding tensor, shape [batch, 1024]
179
+ Returns:
180
+ logits for this vocab shard, shape [batch, output_dim]
181
+ """
182
+ # Normalize + dropout
183
+ x = self.norm(x)
184
+ x = self.dropout(x)
 
 
 
185
 
186
+ # Linear projection to vocab slice
187
+ logits = self.linear(x)
188
+
189
+ # (Optional) softmax for probabilities, but usually the main model handles this
190
+ # probs = F.softmax(logits, dim=-1)
191
+ return logits
192
+
193
+
194
+ # Create and return the chunk buffer
195
+ chunk_model = ChunkBuffer(chunk_file, chunk_config)
196
+ # Ensure the chunk_model.config is the up-to-date config (including any assigned offsets)
197
+ chunk_model.config = chunk_config
198
+ print(f"[INFO] Loaded chunk {chunk.chunk_id} ({chunk_config.get('size_bytes', 0)} bytes) from {chunk.files[0]}")
199
+
200
+ return chunk_model
201
+
202
+ except Exception as e:
203
+ raise Exception(f"Failed to load chunk: {str(e)}")
204
 
205
  async def process_tensor(chunk_id: int, inputs: torch.Tensor) -> torch.Tensor:
206
  """Process input tensor through the specified chunk"""