openagi-agi committed on
Commit
7cd7caf
·
verified ·
1 Parent(s): dd914f5

Upload 8 files

Browse files
Files changed (8) hide show
  1. decoder.py +103 -0
  2. formatter.py +77 -0
  3. main.py +79 -0
  4. pre_embed.py +55 -0
  5. test_chat.py +91 -0
  6. test_embed.py +18 -0
  7. test_full.py +203 -0
  8. testcuda.py +6 -0
decoder.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from transformers import T5Tokenizer
from sentence_transformers import SentenceTransformer

# ===== CONFIG =====
INPUT_FILE = "chat_1turn.csv"
EMB_FILE = "chat_embeddings.pt"
MODEL_NAME = "Snowflake/snowflake-arctic-embed-l-v2.0"
EPOCHS = 80
BATCH_SIZE = 16
HIDDEN_DIM = 512
MAX_LEN = 64

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ===== Load CSV =====
df = pd.read_csv(INPUT_FILE)
sources = df["source"].fillna("").tolist()
targets = df["target"].fillna("").tolist()

# ===== Tokenizer =====
tokenizer = T5Tokenizer.from_pretrained("t5-small")
target_enc = tokenizer(targets, padding=True, truncation=True,
                       return_tensors="pt", max_length=MAX_LEN)
input_ids = target_enc["input_ids"].to(device)
attention_mask = target_enc["attention_mask"].to(device)

# ===== Load embeddings =====
# FIX: pass map_location=device — the .pt file is produced on a CUDA machine
# (pre_embed.py saves GPU tensors), so a plain torch.load crashes on a
# CPU-only host with "Attempting to deserialize object on a CUDA device".
emb_data = torch.load(EMB_FILE, map_location=device)
x_embeddings = emb_data["source"].to(device)  # not used directly in this training
y_embeddings = emb_data["target"].to(device)  # used to condition decoder
36
+
37
# ===== Decoder =====
class EmbeddingDecoder(nn.Module):
    """GRU decoder that turns a sentence embedding into token logits.

    The embedding is projected into the GRU's initial hidden state; tokens
    are then produced autoregressively, optionally teacher-forced per step.
    """

    def __init__(self, input_dim, hidden_dim, vocab_size):
        super().__init__()
        self.bridge = nn.Linear(input_dim, hidden_dim)    # embedding -> h0
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, emb_vec, target_ids=None, teacher_forcing_ratio=0.5, max_len=MAX_LEN):
        batch = emb_vec.size(0)
        hidden = self.bridge(emb_vec).unsqueeze(0)        # [1,B,H]

        # First input token: T5's pad id (0) doubles as the start symbol.
        step_input = torch.full((batch, 1), tokenizer.pad_token_id, device=emb_vec.device)

        step_logits = []
        for t in range(max_len):
            out, hidden = self.gru(self.embed(step_input), hidden)  # [B,1,H]
            logits = self.fc(out.squeeze(1))                        # [B,V]
            step_logits.append(logits.unsqueeze(1))

            # Teacher-force with the ground-truth token, else feed the argmax.
            # (rand is only drawn when targets are available — same
            # short-circuit order as before.)
            use_teacher = (target_ids is not None
                           and t < target_ids.size(1)
                           and torch.rand(1).item() < teacher_forcing_ratio)
            if use_teacher:
                step_input = target_ids[:, t].unsqueeze(1)
            else:
                step_input = torch.argmax(logits, dim=-1, keepdim=True)

        return torch.cat(step_logits, dim=1)  # [B, max_len, V]
66
+
67
# ===== Train =====
decoder = EmbeddingDecoder(y_embeddings.shape[1], HIDDEN_DIM, tokenizer.vocab_size).to(device)
optimizer = optim.Adam(decoder.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

print("Training decoder...")
for epoch in range(EPOCHS):
    decoder.train()
    total_loss = 0.0
    # Mini-batches over the target-space embeddings, in file order.
    for start in range(0, len(y_embeddings), BATCH_SIZE):
        emb_batch = y_embeddings[start:start + BATCH_SIZE]
        id_batch = input_ids[start:start + BATCH_SIZE]

        optimizer.zero_grad()
        logits = decoder(emb_batch, target_ids=id_batch,
                         teacher_forcing_ratio=0.7, max_len=id_batch.size(1))
        loss = criterion(logits.reshape(-1, logits.size(-1)), id_batch.reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {total_loss:.4f}")
88
+
89
# ===== Inference =====
embedder = SentenceTransformer(MODEL_NAME, device=device)

def generate(text, max_len=30, use_mapper=False, mapper=None):
    """Embed `text`, optionally map it into target space, then greedy-decode."""
    with torch.no_grad():
        emb = embedder.encode([text], convert_to_tensor=True, device=device)
        if use_mapper and mapper is not None:
            emb = mapper(emb)
        logits = decoder(emb, target_ids=None, teacher_forcing_ratio=0.0, max_len=max_len)
        ids = torch.argmax(logits, dim=-1).squeeze(0).tolist()
    return tokenizer.decode(ids, skip_special_tokens=True)

# ===== Test =====
print("Hi ->", generate("Hi"))
formatter.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from datasets import load_dataset
from transformers import T5Tokenizer
import pandas as pd, csv, re
from tqdm import tqdm

# ── Config ────────────────────────────────────────────────────────────────
jsonl_path = "lmsys_chat_1m_full.jsonl"   # local file
use_subset = False                        # False ⇒ full 1 M rows
num_samples = 500                         # rows kept when use_subset is True
max_turn_pairs = 1                        # N user+assistant exchanges kept
max_input_tokens = 512                    # fits t5-small/base
# ──────────────────────────────────────────────────────────────────────────

tok = T5Tokenizer.from_pretrained("t5-small")
ds = load_dataset("json", data_files=jsonl_path, split="train")

if use_subset:
    ds = ds.select(range(min(num_samples, len(ds))))
    print(f"🔍 subset → {len(ds)} rows")
20
+
21
def mostly_ascii(s: str, threshold: float = .3) -> bool:
    """True when fewer than `threshold` of the characters are non-ASCII.

    An empty string counts as not-mostly-ASCII (returns False).
    """
    if not s:
        return False
    non_ascii = sum(1 for ch in s if ord(ch) > 127)
    return non_ascii / len(s) < threshold
26
+
27
def format_turns(conv):
    """Render [{'role', 'content'}, ...] messages as 'Role: text' lines."""
    lines = []
    for msg in conv:
        lines.append(f"{msg['role'].capitalize()}: {msg['content'].strip()}")
    return lines
29
+
30
def build_pair(turns, max_tokens=512):
    """Build a (prompt, target) pair from formatted turn strings.

    Keeps the last `max_turn_pairs` exchanges, drops the oldest turn until
    the prompt fits `max_tokens`, and filters short / non-ASCII pairs.
    Returns None when no usable pair can be made.
    """
    if len(turns) < max_turn_pairs * 2:
        return None

    use_turns = turns[-(max_turn_pairs * 2):]        # last N exchanges

    prompt = "chat:\n\n" + "\n\n".join(use_turns[:-1])
    target = use_turns[-1].replace("Assistant: ", "", 1)

    # Trim the oldest remaining turn until the prompt fits the token budget.
    for _ in range(max_turn_pairs):
        if len(tok.tokenize(prompt)) <= max_tokens:
            break                                    # fits → good
        sep_pos = prompt.find("\n\n", len("chat:\n\n"))
        if sep_pos == -1:                            # no more turns to drop
            return None
        prompt = "chat:\n\n" + prompt[sep_pos + 2:]
    else:
        return None                                  # still too long after all trims

    if len(prompt) < 30 or len(target) < 10:
        return None
    if not mostly_ascii(prompt + target):
        return None
    return prompt, target
58
+
59
+
60
rows, kept = [], 0
for ex in tqdm(ds, desc="formatting"):
    conv = ex.get("conversation")
    if not isinstance(conv, list):
        continue
    p = build_pair(format_turns(conv))
    if p:
        rows.append({"source": p[0], "target": p[1]})
        kept += 1

print(f"✅ kept {kept} examples")

pd.DataFrame(rows).to_csv(
    "chat_1turn.csv",
    index=False,
    quoting=csv.QUOTE_ALL,   # preserves embedded newlines
    encoding="utf-8"
)
# FIX: the message previously claimed "t5_chat_4turn.csv"; the file actually
# written (and read by pre_embed.py / decoder.py) is chat_1turn.csv.
print("💾 saved → chat_1turn.csv")
main.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# ===== Load precomputed embeddings =====
# FIX: map_location="cpu" — the embeddings are saved from a CUDA run
# (pre_embed.py), so a plain torch.load fails on a CPU-only machine.
# They are moved onto the selected device just below.
emb_data = torch.load("chat_embeddings.pt", map_location="cpu")

x_embeddings = emb_data["source"]   # [N, D]
y_embeddings = emb_data["target"]   # [N, D]

print("Source shape:", x_embeddings.shape)
print("Target shape:", y_embeddings.shape)

embedding_dim = x_embeddings.shape[1]
num_samples = x_embeddings.shape[0]

# ===== Device =====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

x_embeddings = x_embeddings.to(device)
y_embeddings = y_embeddings.to(device)
24
+
25
# ===== Define model =====
class SemanticMapper(nn.Module):
    """MLP mapping source-sentence embeddings into target-embedding space."""

    def __init__(self, dim):
        super().__init__()
        expanded = dim * 2
        self.net = nn.Sequential(
            nn.Linear(dim, expanded),
            nn.ReLU(),
            nn.Linear(expanded, dim),
        )

    def forward(self, x):
        return self.net(x)
37
+
38
model = SemanticMapper(embedding_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CosineEmbeddingLoss()

# ===== Training config =====
epochs = 20
batch_size = 64
loss_history = []

# ===== Training loop =====
for epoch in range(epochs):
    perm = torch.randperm(num_samples, device=device)
    epoch_loss = 0.0
    steps = 0
    for i in range(0, num_samples, batch_size):
        idx = perm[i:i + batch_size]
        x_batch = x_embeddings[idx]
        y_batch = y_embeddings[idx]
        target = torch.ones(x_batch.size(0), device=device)  # cosine target = +1

        y_pred = model(x_batch)
        loss = criterion(y_pred, y_batch, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        steps += 1

    # FIX: average over the actual number of optimizer steps; the previous
    # divisor `num_samples / batch_size` is fractional and wrong whenever the
    # dataset size is not a multiple of the batch size.
    avg_loss = epoch_loss / max(1, steps)
    loss_history.append(avg_loss)
    print(f"Epoch {epoch + 1}/{epochs} - Loss: {avg_loss:.6f}")

# ===== Plot loss curve =====
plt.plot(loss_history, marker="o")
plt.title("Training Loss (Cosine Similarity)")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid(True)
plt.show()

# Save the trained model.
# FIX (consistency): test_chat.py loads this file expecting a
# {"state_dict", "dim"} checkpoint (same format test_full.py writes);
# saving a bare state_dict broke that consumer.
torch.save({"state_dict": model.state_dict(), "dim": embedding_dim},
           "semantic_mapper.pth")
pre_embed.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import pandas as pd
from sentence_transformers import SentenceTransformer
import time
import os

# CONFIGURATION
INPUT_FILE = "chat_1turn.csv"
OUTPUT_FILE = "chat_embeddings.pt"
MODEL_NAME = "Snowflake/snowflake-arctic-embed-l-v2.0"
BATCH_SIZE = 128                     # Go big or go slow
USE_GPU = torch.cuda.is_available()
MAX_ROWS = 2000                      # Set to e.g. 1000 for quick dev tests

# 🔧 Sanity checks
assert os.path.exists(INPUT_FILE), f"❌ File not found: {INPUT_FILE}"

# 🚀 Load model
print(f"🧠 Loading model: {MODEL_NAME} {'[GPU]' if USE_GPU else '[CPU]'}")
model = SentenceTransformer(MODEL_NAME, device="cuda" if USE_GPU else "cpu")

# 📂 Load data
print("📂 Reading CSV...")
df = pd.read_csv(INPUT_FILE)
assert 'source' in df.columns and 'target' in df.columns, "❌ Missing 'source' or 'target' column!"

if MAX_ROWS:
    # Truthy MAX_ROWS caps the dataset for quick experiments.
    df = df.head(MAX_ROWS)

sources = df['source'].fillna("").tolist()
targets = df['target'].fillna("").tolist()
32
+
33
# ⏱️ Embed all at once
def embed_all(texts, label):
    """Encode `texts` with the global SentenceTransformer; returns an [N, D] tensor.

    FIX: dropped the invalid `torch_dtype=torch.int8` argument —
    SentenceTransformer.encode() has no such parameter (it either raises
    TypeError or is silently ignored, depending on version). Quantized
    output, if ever wanted, is requested via `precision="int8"` instead.
    """
    print(f"⚙️ Embedding {label} ({len(texts)} items)...")
    start = time.time()
    embeddings = model.encode(
        texts,
        batch_size=BATCH_SIZE,
        convert_to_tensor=True,
        normalize_embeddings=True,
        show_progress_bar=True,
        device="cuda" if USE_GPU else "cpu",
    )
    print(f"✅ {label} embedding done in {time.time() - start:.2f}s")
    return embeddings
48
+
49
source_tensor = embed_all(sources, "source")
target_tensor = embed_all(targets, "target")

# 💾 Save
print(f"💾 Saving to {OUTPUT_FILE}...")
payload = {"source": source_tensor, "target": target_tensor}
torch.save(payload, OUTPUT_FILE)
print(f"✅ Saved {len(sources)} embeddings to {OUTPUT_FILE}")
test_chat.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
# chat.py - use trained mapper + decoder interactively

import torch
from transformers import T5Tokenizer
from sentence_transformers import SentenceTransformer
import torch.nn as nn

# ===== CONFIG =====
MAPPER_PTH = "semantic_mapper.pth"
DECODER_PTH = "embedding_decoder.pth"
MODEL_NAME = "Snowflake/snowflake-arctic-embed-l-v2.0"
MAX_LEN = 4096
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ===== LOAD TOKENIZER =====
tokenizer = T5Tokenizer.from_pretrained("t5-small")
pad_id = tokenizer.pad_token_id   # also used as the decoder start token
eos_id = tokenizer.eos_token_id   # T5's </s>
20
+
21
# ===== MODEL CLASSES (same defs as training) =====
class SemanticMapper(torch.nn.Module):
    """Source-embedding -> target-embedding MLP (must match the training def)."""

    def __init__(self, dim):
        super().__init__()
        hidden = dim * 2
        self.net = torch.nn.Sequential(
            torch.nn.Linear(dim, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, dim),
        )

    def forward(self, x):
        return self.net(x)
31
+
32
class EmbeddingDecoder(nn.Module):
    """GRU decoder conditioned on the sentence embedding at every step.

    The output projection is weight-tied to the token embedding table, so
    hidden_dim must equal the embedding width.
    """

    def __init__(self, input_dim, hidden_dim, vocab_size, p=0.2):
        super().__init__()
        self.bridge = nn.Linear(input_dim, hidden_dim)      # emb -> h0
        self.embed = nn.Embedding(vocab_size, hidden_dim)   # token -> hidden
        self.gru = nn.GRU(hidden_dim + input_dim, hidden_dim, batch_first=True)
        self.ln = nn.LayerNorm(hidden_dim)
        self.fc = nn.Linear(hidden_dim, vocab_size, bias=True)
        self.drop = nn.Dropout(p)
        self.fc.weight = self.embed.weight                  # tie weights

    @torch.no_grad()
    def greedy_decode(self, emb_vec, max_len, start_id, eos_id):
        """Greedy decode until every row emits `eos_id` or `max_len` is hit."""
        batch = emb_vec.size(0)
        hidden = torch.tanh(self.bridge(emb_vec)).unsqueeze(0)
        cur = torch.full((batch, 1), start_id, dtype=torch.long, device=emb_vec.device)
        generated = []
        for _ in range(max_len):
            token_h = self.drop(self.embed(cur))                       # [B,1,H]
            gru_in = torch.cat([token_h, emb_vec.unsqueeze(1)], dim=-1)
            out, hidden = self.gru(gru_in, hidden)
            logits = self.fc(self.drop(self.ln(out.squeeze(1))))
            # NOTE(review): `pad_id` is the module-level global, not a
            # parameter — this class only works inside this script.
            logits[:, pad_id] = -1e9                                   # never emit PAD
            next_id = torch.argmax(logits, dim=-1)
            generated.append(next_id.unsqueeze(1))
            if (next_id == eos_id).all():
                break
            cur = next_id.unsqueeze(1)
        return torch.cat(generated, dim=1)
62
+
63
+
64
# ===== LOAD MODELS =====
def _load_ckpt(path):
    """Read a {"state_dict", "dim", ...} checkpoint onto DEVICE."""
    return torch.load(path, map_location=DEVICE)

mapper_ckpt = _load_ckpt(MAPPER_PTH)
mapper = SemanticMapper(mapper_ckpt["dim"]).to(DEVICE)
mapper.load_state_dict(mapper_ckpt["state_dict"])
mapper.eval()

dec_ckpt = _load_ckpt(DECODER_PTH)
decoder = EmbeddingDecoder(dec_ckpt["dim"], 512, dec_ckpt["vocab_size"]).to(DEVICE)
decoder.load_state_dict(dec_ckpt["state_dict"])
decoder.eval()

embedder = SentenceTransformer(MODEL_NAME, device=DEVICE)
76
+
77
# ===== CHAT LOOP =====
def chat():
    """Interactive REPL: embed user text, map it, greedy-decode a reply."""
    print("Chat ready. Type 'quit' to exit.")
    while True:
        user = input("User: ").strip()
        if not user or user.lower() in {"quit", "exit"}:
            break
        x = embedder.encode([user], convert_to_tensor=True, device=DEVICE).detach().clone()
        y_pred = mapper(x)
        ids = decoder.greedy_decode(y_pred, max_len=MAX_LEN,
                                    start_id=pad_id, eos_id=eos_id)[0].tolist()
        reply = tokenizer.decode(ids, skip_special_tokens=True)
        print("Bot:", reply)

if __name__ == "__main__":
    chat()
test_embed.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from sentence_transformers import SentenceTransformer
import time
import torch

model = SentenceTransformer("Snowflake/snowflake-arctic-embed-xs", device="cuda")

texts = ["The quick brown fox jumps over the lazy dog."] * 1000

start = time.time()
# FIX: removed the invalid `torch_dtype=torch.int8` kwarg — encode() has no
# such parameter (quantized output is requested via precision="int8").
embeddings = model.encode(
    texts,
    batch_size=512,
    convert_to_tensor=True,
    normalize_embeddings=True,
    device="cuda",
)
print(f"⏱️ Embedded 1000 items in {time.time() - start:.2f} seconds")
test_full.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
import torch, torch.nn as nn, torch.optim as optim
import pandas as pd
import matplotlib.pyplot as plt
from transformers import T5Tokenizer
from sentence_transformers import SentenceTransformer

# ==== Config ====
EMB_FILE = "chat_embeddings.pt"    # {"source": [N,D], "target": [N,D]}
CSV_FILE = "chat_1turn.csv"        # columns: source, target
MODEL_NAME = "Snowflake/snowflake-arctic-embed-l-v2.0"
EPOCHS_MAPPER = 20
EPOCHS_DECODER = 160
BATCH_SIZE_MAP = 64
BATCH_SIZE_DEC = 64
LR_MAPPER = 1e-3
LR_DECODER = 1e-3
HIDDEN_DIM = 512
MAX_LEN = 64
PLOT_LOSS = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ==== Load embeddings & CSV ====
emb = torch.load(EMB_FILE, map_location=device)
x_embeddings = emb["source"].to(device)   # [N,D]
y_embeddings = emb["target"].to(device)   # [N,D]
N, D = x_embeddings.shape
print(f"Loaded embeddings: N={N}, D={D}")

df = pd.read_csv(CSV_FILE)
assert "target" in df.columns
targets = df["target"].fillna("").tolist()
35
+
36
# ==== Mapper: x_emb -> y_emb ====
class SemanticMapper(nn.Module):
    """Two-layer MLP mapping source embeddings into target-embedding space."""

    def __init__(self, dim):
        super().__init__()
        wide = dim * 2
        self.net = nn.Sequential(
            nn.Linear(dim, wide),
            nn.ReLU(),
            nn.Linear(wide, dim),
        )

    def forward(self, x):
        return self.net(x)
45
+
46
mapper = SemanticMapper(D).to(device)
opt_map = optim.Adam(mapper.parameters(), lr=LR_MAPPER)
crit_map = nn.CosineEmbeddingLoss()

print("\nTraining mapper...")
map_losses = []
for ep in range(EPOCHS_MAPPER):
    perm = torch.randperm(N, device=device)
    total, steps = 0.0, 0
    for lo in range(0, N, BATCH_SIZE_MAP):
        idx = perm[lo:lo + BATCH_SIZE_MAP]
        xb, yb = x_embeddings[idx], y_embeddings[idx]
        ones = torch.ones(xb.size(0), device=device)   # cosine target = +1
        loss = crit_map(mapper(xb), yb, ones)
        opt_map.zero_grad()
        loss.backward()
        opt_map.step()
        total += loss.item()
        steps += 1
    avg = total / max(1, steps)
    map_losses.append(avg)
    print(f"Mapper Epoch {ep+1}/{EPOCHS_MAPPER} - Loss: {avg:.6f}")

if PLOT_LOSS:
    plt.figure(); plt.plot(map_losses, marker="o"); plt.title("Mapper Loss"); plt.grid(True); plt.show()

torch.save({"state_dict": mapper.state_dict(), "dim": D}, "semantic_mapper.pth")
print("Saved mapper -> semantic_mapper.pth")
73
+
74
# ==== Decoder: y_emb -> target text ====
tokenizer = T5Tokenizer.from_pretrained("t5-small")
tok = tokenizer(targets, padding=True, truncation=True, max_length=MAX_LEN,
                return_tensors="pt", add_special_tokens=True)
labels = tok["input_ids"].to(device)   # [N,L]
pad_id = tokenizer.pad_token_id
eos_id = tokenizer.eos_token_id        # T5 uses </s> as EOS

# Strict teacher-forcing inputs: right-shift the labels, with pad_id as the
# BOS symbol (T5 convention): y_in[:,0]=pad, y_in[:,t]=labels[:,t-1];
# the model predicts labels[t] given y_in[t].
y_in = torch.full_like(labels, pad_id)
y_in[:, 1:] = labels[:, :-1]
y_out = labels
87
+
88
class EmbeddingDecoder(nn.Module):
    """
    GRU decoder conditioned on the sentence embedding at every step.

    - Strong conditioning: the embedding is concatenated to each step input.
    - Weight tying: fc.weight is embed.weight (hidden_dim == embedding width).
    - Deterministic teacher forcing via pre-shifted inputs (no ratios).
    """

    def __init__(self, input_dim, hidden_dim, vocab_size, p=0.2):
        super().__init__()
        self.bridge = nn.Linear(input_dim, hidden_dim)      # emb -> h0
        self.embed = nn.Embedding(vocab_size, hidden_dim)   # token -> hidden
        self.gru = nn.GRU(hidden_dim + input_dim, hidden_dim, batch_first=True)
        self.ln = nn.LayerNorm(hidden_dim)
        self.fc = nn.Linear(hidden_dim, vocab_size, bias=True)
        self.drop = nn.Dropout(p)
        self.fc.weight = self.embed.weight                  # tie weights

    def forward_teacher_forced(self, emb_vec, in_ids, max_len):
        """emb_vec: [B,D]; in_ids: [B,L] shifted inputs. Returns [B,L,V] logits."""
        hidden = torch.tanh(self.bridge(emb_vec)).unsqueeze(0)           # [1,B,H]
        per_step = []
        for t in range(max_len):
            tok_ids = in_ids[:, t].unsqueeze(1)                          # [B,1]
            tok_h = self.drop(self.embed(tok_ids))                       # [B,1,H]
            gru_in = torch.cat([tok_h, emb_vec.unsqueeze(1)], dim=-1)    # [B,1,H+D]
            out, hidden = self.gru(gru_in, hidden)                       # [B,1,H]
            normed = self.ln(out.squeeze(1))                             # [B,H]
            per_step.append(self.fc(self.drop(normed)).unsqueeze(1))     # [B,1,V]
        return torch.cat(per_step, dim=1)                                # [B,L,V]

    @torch.no_grad()
    def greedy_decode(self, emb_vec, max_len, start_id, eos_id):
        """Pure greedy decoding with EOS early-stop; PAD is masked to curb loops."""
        batch = emb_vec.size(0)
        hidden = torch.tanh(self.bridge(emb_vec)).unsqueeze(0)
        cur = torch.full((batch, 1), start_id, dtype=torch.long, device=emb_vec.device)
        decoded = []
        finished = torch.zeros(batch, dtype=torch.bool, device=emb_vec.device)

        for _ in range(max_len):
            tok_h = self.embed(cur)                                      # [B,1,H]
            gru_in = torch.cat([tok_h, emb_vec.unsqueeze(1)], dim=-1)
            out, hidden = self.gru(gru_in, hidden)
            logits = self.fc(out.squeeze(1))                             # [B,V]
            # NOTE(review): `pad_id` is the module-level global, not a
            # parameter — the class is only usable inside this script.
            logits[:, pad_id] = -1e9                                     # discourage PAD
            nxt = torch.argmax(logits, dim=-1)                           # [B]
            decoded.append(nxt.unsqueeze(1))
            finished |= (nxt == eos_id)
            if finished.all():
                break
            cur = nxt.unsqueeze(1)
        return torch.cat(decoded, dim=1)                                 # [B,T]
147
+
148
decoder = EmbeddingDecoder(D, HIDDEN_DIM, tokenizer.vocab_size).to(device)
opt_dec = optim.Adam(decoder.parameters(), lr=LR_DECODER)
crit_dec = nn.CrossEntropyLoss(ignore_index=pad_id)   # no smoothing (small N)

print("\nTraining decoder...")
dec_losses = []
steps = (N + BATCH_SIZE_DEC - 1) // BATCH_SIZE_DEC    # batches per epoch
for ep in range(EPOCHS_DECODER):
    perm = torch.randperm(N, device=device)
    total = 0.0
    for lo in range(0, N, BATCH_SIZE_DEC):
        idx = perm[lo:lo + BATCH_SIZE_DEC]
        eb = y_embeddings[idx]    # condition on TRUE target-space embeddings
        yin = y_in[idx]           # shifted inputs
        yout = y_out[idx]         # labels

        opt_dec.zero_grad()
        logits = decoder.forward_teacher_forced(eb, yin, max_len=yout.size(1))  # [B,L,V]
        loss = crit_dec(logits.reshape(-1, logits.size(-1)), yout.reshape(-1))
        loss.backward()
        nn.utils.clip_grad_norm_(decoder.parameters(), 1.0)   # tame GRU gradients
        opt_dec.step()
        total += loss.item()
    avg = total / max(1, steps)
    dec_losses.append(avg)
    print(f"Decoder Epoch {ep+1}/{EPOCHS_DECODER} - Loss: {avg:.4f}")

if PLOT_LOSS:
    plt.figure(); plt.plot(dec_losses, marker="o"); plt.title("Decoder Loss"); plt.grid(True); plt.show()

torch.save({"state_dict": decoder.state_dict(), "dim": D, "vocab_size": tokenizer.vocab_size},
           "embedding_decoder.pth")
print("Saved decoder -> embedding_decoder.pth")
181
+
182
# ==== E2E inference ====
embedder = SentenceTransformer(MODEL_NAME, device=device)

# FIX: the old code raised the dimension-mismatch RuntimeError *inside* a try
# whose `except Exception: pass` immediately swallowed it — the guard was dead
# code. Only the (possibly unavailable) dimension query is guarded now; a real
# mismatch propagates.
emb_dim = None
try:
    emb_dim = embedder.get_sentence_embedding_dimension()
except Exception:
    pass  # older sentence-transformers may not expose the accessor
if emb_dim is not None and emb_dim != D:
    raise RuntimeError(f"Embedder dim {emb_dim} != training dim {D}. "
                       "Regenerate embeddings with same MODEL_NAME.")

@torch.no_grad()
def generate(text: str, max_len: int = 24) -> str:
    """source text -> x_emb -> mapper -> y_emb -> greedy decode -> reply."""
    x = embedder.encode([text], convert_to_tensor=True, device=device)   # [1,D]
    y_pred = mapper(x)                                                   # [1,D]
    ids = decoder.greedy_decode(y_pred, max_len=max_len, start_id=pad_id, eos_id=eos_id)[0].tolist()
    return tokenizer.decode(ids, skip_special_tokens=True)

print("\nE2E test:")
inp = "User: Hi"
print(f"{inp} ->", generate(inp))
testcuda.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
import torch

# FIX: the original unconditionally queried device 0, which raises a
# RuntimeError on CPU-only machines; the per-device details are only safe
# (and meaningful) to query when CUDA is actually available.
print(torch.cuda.is_available())
print(torch.cuda.device_count())
if torch.cuda.is_available():
    print(torch.cuda.current_device())
    print(torch.cuda.get_device_name(0))