AbstractPhil commited on
Commit
b29898a
·
verified ·
1 Parent(s): ae415e4

Create REAME.md

Browse files
Files changed (1) hide show
  1. REAME.md +118 -0
REAME.md ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ The 155,000-step version was weight-trained on about 158,100,000 prompt samples using the
2
+
3
+ AbstractPhil/T5-Small-Human-Attentive-Try2-Pass3
4
+
5
+ This T5-small model is fried to echo and interpolate math in complex intended ways. I haven't given it the full robust check yet, but it's definitely pretty fed.
6
+
7
+ This adapter is trained using T5 inputs with the following training script:
8
+
9
+
10
+ ```
11
def main():
    """Train a T5 -> CLIP-ViT-L/14 velocity adapter with AMP and gradient accumulation.

    Frozen T5 encoder states are mapped by the adapter toward the CLIP text
    embedding space. The loss combines a heteroscedastic delta loss with a
    lightly weighted cosine anchor-alignment term. A checkpoint is written
    every ``push_every_n_steps`` steps.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # HF Hub settings
    hf_repo_id = "AbstractPhil/t5-to-vit-l-14-velocity-adapter-v3-100m-77tok"
    push_every_n_steps = 5000

    # Tokenizers & frozen models (eval mode; only the adapter is trained)
    t5_tok = T5TokenizerFast.from_pretrained("t5-small")
    t5_mod = T5EncoderModel.from_pretrained(
        "AbstractPhil/T5-Small-Human-Attentive-Try2-Pass3"
    ).to(device).eval()
    clip_tok = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14")
    clip_mod = CLIPTextModel.from_pretrained(
        "openai/clip-vit-large-patch14"
    ).to(device).eval()

    # Adapter & optimizer
    adapter = RobustVelocityAdapter(out_tokens=77).to(device)
    optimizer = optim.AdamW(adapter.parameters(), lr=5e-4)

    # Compile models for speed. Keep a handle on the uncompiled adapter:
    # torch.compile returns a wrapper whose state_dict() keys carry an
    # "_orig_mod." prefix, which would break loading the checkpoint back
    # into a plain (uncompiled) module.
    uncompiled_adapter = adapter
    t5_mod = torch.compile(t5_mod)
    clip_mod = torch.compile(clip_mod)
    adapter = torch.compile(adapter)

    scaler = GradScaler()  # for mixed precision

    # Data (batch_size=None: the iterable dataset yields individual rows)
    dataset = ParsedMultiCharDataset("AbstractPhil/human-templated-captions-1b",
                                     num_files=12)
    loader = DataLoader(dataset,
                        batch_size=None,
                        num_workers=4,
                        pin_memory=True)
    iterator = iter(loader)

    batch_size = 256
    accum_steps = 4  # effective batch size = 256 * 4 = 1024
    max_steps = math.ceil(dataset.total_rows / batch_size)
    pbar = tqdm(total=max_steps, desc="Adapter training")

    for step in range(1, max_steps + 1):
        # Zero grads at the start of each accumulation window
        if (step - 1) % accum_steps == 0:
            optimizer.zero_grad()

        # 1) Collect a batch of caption strings, restarting the loader
        #    when it is exhausted.
        texts = []
        for _ in range(batch_size):
            try:
                _, txt = next(iterator)
            except StopIteration:
                iterator = iter(loader)
                _, txt = next(iterator)
            texts.append(txt)

        # 2) Tokenize. T5 pads dynamically to the longest caption in the
        #    batch (<= 77); CLIP pads to a fixed length of 77.
        t5_inputs = t5_tok(texts,
                           padding=True,
                           truncation=True,
                           max_length=77,
                           return_tensors="pt").to(device)
        clip_inputs = clip_tok(texts,
                               padding="max_length",
                               truncation=True,
                               max_length=77,
                               return_tensors="pt").to(device)

        # 3) Forward + loss in mixed precision
        with autocast():
            t5_seq = t5_mod(**t5_inputs).last_hidden_state        # [B, <=77, 512]
            clip_seq = clip_mod(**clip_inputs).last_hidden_state  # [B, 77, 768]

            anchor_pred, delta_pred, sigma_pred = adapter(t5_seq)
            delta_target = clip_seq - anchor_pred

            loss_delta = hetero_loss(delta_pred, delta_target, sigma_pred)
            # Cosine anchor alignment (small 0.1 weight keeps it a regularizer)
            cos_sim = nn.functional.cosine_similarity(
                anchor_pred.reshape(-1, 768),
                clip_seq.reshape(-1, 768),
                dim=-1
            ).mean()
            loss_anchor = (1 - cos_sim) * 0.1

            loss = loss_delta + loss_anchor
            loss = loss / accum_steps  # scale for accumulation

        # 4) Backward; step at the end of each window. The `step == max_steps`
        #    clause flushes a trailing partial window so its gradients are not
        #    silently dropped when max_steps is not a multiple of accum_steps.
        scaler.scale(loss).backward()
        if step % accum_steps == 0 or step == max_steps:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(adapter.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

        pbar.update(1)
        pbar.set_postfix(loss=(loss.item() * accum_steps))

        # 5) Save (and optionally push) every N steps — save from the
        #    uncompiled module so checkpoint keys are unprefixed.
        if step % push_every_n_steps == 0:
            ckpt = f"/content/drive/MyDrive/t5-adapter/t5-to-vit-l-14-velocity-adapter-v3-100m-77tok_step_{step}.safetensors"
            save_file(uncompiled_adapter.state_dict(), ckpt)
            #upload_file(ckpt, ckpt, repo_id=hf_repo_id)

    pbar.close()