righthook75 commited on
Commit
2e11098
·
verified ·
1 Parent(s): fd28dbf

Upload training.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training.py +164 -0
training.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import zipfile
4
+ import tempfile
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+ from torch.utils.data import Dataset
10
+ from PIL import Image
11
+
12
+
13
class SAM3FineTuneDataset(Dataset):
    """Dataset of accepted detections for SAM3 fine-tuning.

    One sample per accepted detection: the processor is invoked with the
    detection's bounding box as a box prompt, and the stored ground-truth
    mask is returned as a (1, H, W) float tensor.
    """

    def __init__(self, images_dict, detections, processor):
        """
        Args:
            images_dict: dict mapping filename -> PIL.Image
            detections: list of detection dicts with keys:
                image_path, box, mask (numpy H×W bool/uint8)
            processor: Sam3Processor instance
        """
        self.images_dict = images_dict
        self.detections = detections
        self.processor = processor

    def __len__(self):
        return len(self.detections)

    def __getitem__(self, idx):
        record = self.detections[idx]
        pil_image = self.images_dict[record["image_path"]]
        xyxy = record["box"]  # [x1, y1, x2, y2]

        # Processor output keeps its batch dimension (batch-size-1 training).
        encoded = self.processor(
            images=pil_image,
            input_boxes=[[xyxy]],
            input_boxes_labels=[[1]],
            return_tensors="pt",
        )

        # Ground truth as a (1, H, W) float tensor for the loss functions.
        target = torch.from_numpy(record["mask"].astype(np.float32)).unsqueeze(0)
        return encoded, target
53
+
54
+
55
def freeze_encoder(model):
    """Disable gradients for the vision encoder only.

    The mask decoder and prompt encoder remain trainable.

    Returns:
        (trainable_count, total_count) parameter element counts.
    """
    for p in model.vision_encoder.parameters():
        p.requires_grad = False

    # Single pass over all parameters, tallying both counts together.
    trainable = 0
    total = 0
    for p in model.parameters():
        n = p.numel()
        total += n
        if p.requires_grad:
            trainable += n
    return trainable, total
66
+
67
+
68
+ def _dice_loss(pred, target):
69
+ """Compute Dice loss between predicted and target masks."""
70
+ pred_flat = pred.flatten(1)
71
+ target_flat = target.flatten(1)
72
+ intersection = (pred_flat * target_flat).sum(1)
73
+ return 1 - (2.0 * intersection + 1) / (pred_flat.sum(1) + target_flat.sum(1) + 1)
74
+
75
+
76
def run_training(model, processor, dataset, epochs, learning_rate, progress_callback=None):
    """Fine-tune SAM3 mask decoder + prompt encoder.

    Args:
        model: Sam3Model with encoder frozen
        processor: Sam3Processor (unused here; kept for interface stability)
        dataset: SAM3FineTuneDataset
        epochs: number of training epochs
        learning_rate: AdamW learning rate
        progress_callback: callable(epoch, step, total_steps, loss_val)

    Returns dict with keys: model, loss_history (list of avg loss per epoch).

    Raises:
        ValueError: if the dataset is empty (previously this surfaced as an
            uninformative ZeroDivisionError when averaging epoch losses).
    """
    num_samples = len(dataset)
    if num_samples == 0:
        raise ValueError("Cannot train: dataset is empty")

    device = next(model.parameters()).device
    # Only parameters left trainable (e.g. after freeze_encoder) are optimized.
    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad),
        lr=learning_rate,
    )

    total_steps = num_samples * epochs
    loss_history = []

    model.train()
    for epoch in range(epochs):
        epoch_losses = []
        for step_in_epoch in range(num_samples):
            global_step = epoch * num_samples + step_in_epoch

            inputs, mask_gt = dataset[step_in_epoch]
            # Move tensors to the model's device; leave non-tensor entries alone.
            inputs = {
                k: v.to(device) if isinstance(v, torch.Tensor) else v
                for k, v in inputs.items()
            }
            mask_gt = mask_gt.to(device)

            outputs = model(**inputs)

            # Predicted masks — assumed (batch, num_masks, H, W); drop the
            # batch dim for batch-size-1 training. TODO confirm against the
            # installed Sam3 output shape.
            pred_masks = outputs.pred_masks
            if pred_masks.dim() == 4:
                pred_masks = pred_masks.squeeze(0)  # (num_masks, H, W)

            # Resample prediction to the ground-truth resolution if needed.
            if pred_masks.shape[-2:] != mask_gt.shape[-2:]:
                pred_masks = F.interpolate(
                    pred_masks.unsqueeze(0),
                    size=mask_gt.shape[-2:],
                    mode="bilinear",
                    align_corners=False,
                ).squeeze(0)

            # Train against the first predicted mask only.
            pred = pred_masks[0:1]  # (1, H, W) logits
            pred_sigmoid = torch.sigmoid(pred)

            # BCE operates on logits; Dice on probabilities.
            bce = F.binary_cross_entropy_with_logits(pred, mask_gt)
            dice = _dice_loss(pred_sigmoid, mask_gt).mean()
            loss = bce + dice

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_val = loss.item()
            epoch_losses.append(loss_val)

            if progress_callback:
                progress_callback(epoch, global_step, total_steps, loss_val)

        loss_history.append(sum(epoch_losses) / len(epoch_losses))

    model.eval()
    return {"model": model, "loss_history": loss_history}
147
+
148
+
149
def get_model_zip_bytes(model, processor):
    """Serialize the fine-tuned model and processor into an in-memory zip.

    Both objects are written via save_pretrained() into a temporary
    ``sam3_finetuned`` directory; every file under it is added to a
    deflate-compressed zip whose raw bytes are returned.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        export_dir = os.path.join(tmpdir, "sam3_finetuned")
        model.save_pretrained(export_dir)
        processor.save_pretrained(export_dir)

        buffer = io.BytesIO()
        with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as archive:
            for root, _dirs, filenames in os.walk(export_dir):
                for filename in filenames:
                    full_path = os.path.join(root, filename)
                    # Archive paths are relative to tmpdir so the zip root
                    # is the sam3_finetuned/ folder itself.
                    archive.write(full_path, os.path.relpath(full_path, tmpdir))
        buffer.seek(0)
        return buffer.getvalue()