import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import matplotlib.pyplot as plt
from transformers import T5Tokenizer
from sentence_transformers import SentenceTransformer

EMB_FILE = "chat_embeddings.pt"
CSV_FILE = "chat_1turn.csv"
MODEL_NAME = "Snowflake/snowflake-arctic-embed-l-v2.0"
EPOCHS_MAPPER = 20
EPOCHS_DECODER = 160
BATCH_SIZE_MAP = 64
BATCH_SIZE_DEC = 64
LR_MAPPER = 1e-3
LR_DECODER = 1e-3
HIDDEN_DIM = 512
MAX_LEN = 64
PLOT_LOSS = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load precomputed sentence embeddings: "source" and "target" tensors of matching shape [N, D].
emb = torch.load(EMB_FILE, map_location=device)
x_embeddings = emb["source"].to(device)
y_embeddings = emb["target"].to(device)
N, D = x_embeddings.shape
print(f"Loaded embeddings: N={N}, D={D}")

# Load the raw target texts that the decoder will learn to reproduce.
df = pd.read_csv(CSV_FILE)
assert "target" in df.columns, f"{CSV_FILE} must contain a 'target' column"
targets = df["target"].fillna("").tolist()


class SemanticMapper(nn.Module):
    """MLP that maps a source-text embedding to a predicted target-text embedding."""

    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * 2), nn.ReLU(),
            nn.Linear(dim * 2, dim),
        )

    def forward(self, x):
        return self.net(x)

mapper = SemanticMapper(D).to(device)
opt_map = optim.Adam(mapper.parameters(), lr=LR_MAPPER)
# With target label 1, CosineEmbeddingLoss is mean(1 - cos(pred, yb)): it pulls each
# mapped embedding toward its paired target embedding.
crit_map = nn.CosineEmbeddingLoss()

print("\nTraining mapper...")
map_losses = []
for ep in range(EPOCHS_MAPPER):
    perm = torch.randperm(N, device=device)
    total, steps = 0.0, 0
    for i in range(0, N, BATCH_SIZE_MAP):
        idx = perm[i:i + BATCH_SIZE_MAP]
        xb, yb = x_embeddings[idx], y_embeddings[idx]
        tgt = torch.ones(xb.size(0), device=device)  # label 1 = "these pairs should be similar"
        pred = mapper(xb)
        loss = crit_map(pred, yb, tgt)
        opt_map.zero_grad()
        loss.backward()
        opt_map.step()
        total += loss.item()
        steps += 1
    avg = total / max(1, steps)
    map_losses.append(avg)
    print(f"Mapper Epoch {ep+1}/{EPOCHS_MAPPER} - Loss: {avg:.6f}")
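
# Optional sanity check (a sketch, not part of the original training flow): report the
# mean cosine similarity between mapped source embeddings and the true target embeddings.
with torch.no_grad():
    _cos = nn.functional.cosine_similarity(mapper(x_embeddings), y_embeddings, dim=-1)
    print(f"Mean cosine(mapper(x), y) on training data: {_cos.mean().item():.4f}")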

if PLOT_LOSS:
    plt.figure()
    plt.plot(map_losses, marker="o")
    plt.title("Mapper Loss")
    plt.grid(True)
    plt.show()

torch.save({"state_dict": mapper.state_dict(), "dim": D}, "semantic_mapper.pth")
print("Saved mapper -> semantic_mapper.pth")
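
# A minimal sketch of how this checkpoint could be reloaded later (assumes the same
# SemanticMapper class definition is available in that context):
#
#   ckpt = torch.load("semantic_mapper.pth", map_location=device)
#   restored_mapper = SemanticMapper(ckpt["dim"]).to(device)
#   restored_mapper.load_state_dict(ckpt["state_dict"])
#   restored_mapper.eval()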

# Tokenize decoder targets with the T5 tokenizer (which appends </s> as EOS).
tokenizer = T5Tokenizer.from_pretrained("t5-small")
tok = tokenizer(targets, padding=True, truncation=True, max_length=MAX_LEN,
                return_tensors="pt", add_special_tokens=True)
labels = tok["input_ids"].to(device)
pad_id = tokenizer.pad_token_id
eos_id = tokenizer.eos_token_id

# Build strict teacher-forcing pairs: decoder inputs are the labels shifted right by one,
# with PAD as the start token (T5 convention); targets are the unshifted labels.
y_in = torch.full_like(labels, pad_id)
y_in[:, 1:] = labels[:, :-1]
y_out = labels
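
# Illustrative shift on hypothetical token ids (T5 uses PAD=0, EOS=1):
#   labels: [17, 23, 42,  1,  0,  0]
#   y_in:   [ 0, 17, 23, 42,  1,  0]
#   y_out:  [17, 23, 42,  1,  0,  0]   (PAD positions are ignored by the loss below)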


class EmbeddingDecoder(nn.Module):
    """
    GRU decoder that turns a single sentence embedding into token ids.
    - Strong conditioning: the embedding is concatenated to the input at every step.
    - Weight tying: fc.weight shares embed.weight (both [vocab_size, hidden_dim]).
    - Deterministic teacher forcing via the pre-built y_in (no forcing ratio).
    """

    def __init__(self, input_dim, hidden_dim, vocab_size, p=0.2):
        super().__init__()
        self.bridge = nn.Linear(input_dim, hidden_dim)
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.gru = nn.GRU(hidden_dim + input_dim, hidden_dim, batch_first=True)
        self.ln = nn.LayerNorm(hidden_dim)
        self.fc = nn.Linear(hidden_dim, vocab_size, bias=True)
        self.drop = nn.Dropout(p)
        # Tie the output projection to the input embedding matrix.
        self.fc.weight = self.embed.weight

    def forward_teacher_forced(self, emb_vec, in_ids, max_len):
        """
        emb_vec: [B, D], in_ids: [B, L] (strict teacher-forcing inputs).
        Returns logits of shape [B, L, V].
        """
        h = torch.tanh(self.bridge(emb_vec)).unsqueeze(0)  # initial hidden state [1, B, H]
        logits_all = []
        for t in range(max_len):
            inp = in_ids[:, t].unsqueeze(1)  # ground-truth previous token
            token_h = self.drop(self.embed(inp))
            step_in = torch.cat([token_h, emb_vec.unsqueeze(1)], dim=-1)
            out, h = self.gru(step_in, h)
            out = self.ln(out.squeeze(1))
            logits = self.fc(self.drop(out))
            logits_all.append(logits.unsqueeze(1))
        return torch.cat(logits_all, dim=1)

    @torch.no_grad()
    def greedy_decode(self, emb_vec, max_len, start_id, eos_id):
        """
        Pure greedy decoding with an EOS stop; PAD is masked out to reduce loops.
        """
        B, _ = emb_vec.shape
        h = torch.tanh(self.bridge(emb_vec)).unsqueeze(0)
        inp = torch.full((B, 1), start_id, dtype=torch.long, device=emb_vec.device)
        out_ids = []
        done = torch.zeros(B, dtype=torch.bool, device=emb_vec.device)

        for _ in range(max_len):
            token_h = self.embed(inp)
            step_in = torch.cat([token_h, emb_vec.unsqueeze(1)], dim=-1)
            out, h = self.gru(step_in, h)
            out = self.ln(out.squeeze(1))  # apply LayerNorm as in the teacher-forced path
            logits = self.fc(out)
            logits[:, pad_id] = -1e9  # relies on the module-level pad_id; never emit PAD
            next_id = torch.argmax(logits, dim=-1)
            out_ids.append(next_id.unsqueeze(1))
            done |= (next_id == eos_id)
            if done.all():
                break
            inp = next_id.unsqueeze(1)
        return torch.cat(out_ids, dim=1)


decoder = EmbeddingDecoder(D, HIDDEN_DIM, tokenizer.vocab_size).to(device)
opt_dec = optim.Adam(decoder.parameters(), lr=LR_DECODER)
crit_dec = nn.CrossEntropyLoss(ignore_index=pad_id)  # PAD positions contribute nothing to the loss

# Train the decoder to reconstruct target text from the true target embeddings;
# at inference time it receives the mapper's predicted embeddings instead.
print("\nTraining decoder...")
dec_losses = []
steps = (N + BATCH_SIZE_DEC - 1) // BATCH_SIZE_DEC
for ep in range(EPOCHS_DECODER):
    perm = torch.randperm(N, device=device)
    total = 0.0
    for i in range(0, N, BATCH_SIZE_DEC):
        idx = perm[i:i + BATCH_SIZE_DEC]
        eb = y_embeddings[idx]
        yin = y_in[idx]
        yout = y_out[idx]

        opt_dec.zero_grad()
        logits = decoder.forward_teacher_forced(eb, yin, max_len=yout.size(1))
        loss = crit_dec(logits.reshape(-1, logits.size(-1)), yout.reshape(-1))
        loss.backward()
        nn.utils.clip_grad_norm_(decoder.parameters(), 1.0)
        opt_dec.step()
        total += loss.item()
    avg = total / max(1, steps)
    dec_losses.append(avg)
    print(f"Decoder Epoch {ep+1}/{EPOCHS_DECODER} - Loss: {avg:.4f}")
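
# Optional spot check (a sketch): greedily decode a few training target embeddings and
# compare them with their reference texts to gauge reconstruction quality.
with torch.no_grad():
    _ids = decoder.greedy_decode(y_embeddings[:3], max_len=MAX_LEN, start_id=pad_id, eos_id=eos_id)
    for _ref, _row in zip(targets[:3], _ids):
        print("ref:", _ref, "| decoded:", tokenizer.decode(_row.tolist(), skip_special_tokens=True))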

if PLOT_LOSS:
    plt.figure()
    plt.plot(dec_losses, marker="o")
    plt.title("Decoder Loss")
    plt.grid(True)
    plt.show()

torch.save({"state_dict": decoder.state_dict(), "dim": D, "vocab_size": tokenizer.vocab_size},
           "embedding_decoder.pth")
print("Saved decoder -> embedding_decoder.pth")


# End-to-end pipeline: embed raw text, map it into target-embedding space, decode to text.
embedder = SentenceTransformer(MODEL_NAME, device=device)
dim = embedder.get_sentence_embedding_dimension()
if dim is not None and dim != D:
    raise RuntimeError(f"Embedder dim {dim} != training dim {D}. Regenerate embeddings with the same MODEL_NAME.")

mapper.eval()
decoder.eval()


@torch.no_grad()
def generate(text: str, max_len: int = 24) -> str:
    # Embed the input text with the same sentence-embedding model used to build the dataset.
    x = embedder.encode([text], convert_to_tensor=True, device=device)
    # Map the source embedding to a predicted target embedding.
    y_pred = mapper(x)
    # Greedily decode it into tokens (T5 uses PAD as the decoder start token).
    ids = decoder.greedy_decode(y_pred, max_len=max_len, start_id=pad_id, eos_id=eos_id)[0].tolist()
    return tokenizer.decode(ids, skip_special_tokens=True)


print("\nE2E test:")
inp = "User: Hi"
print(f"{inp} ->", generate(inp))
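
# A couple of extra smoke-test prompts (illustrative only; any short user turn works):
for probe in ["User: How are you?", "User: Tell me a joke"]:
    print(f"{probe} ->", generate(probe))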