#!/usr/bin/env python3 import torch, torch.nn as nn, torch.optim as optim import pandas as pd import matplotlib.pyplot as plt from transformers import T5Tokenizer from sentence_transformers import SentenceTransformer # ==== Config ==== EMB_FILE = "chat_embeddings.pt" # {"source": [N,D], "target": [N,D]} CSV_FILE = "chat_1turn.csv" # columns: source, target MODEL_NAME = "Snowflake/snowflake-arctic-embed-l-v2.0" EPOCHS_MAPPER = 20 EPOCHS_DECODER = 160 BATCH_SIZE_MAP = 64 BATCH_SIZE_DEC = 64 LR_MAPPER = 1e-3 LR_DECODER = 1e-3 HIDDEN_DIM = 512 MAX_LEN = 64 PLOT_LOSS = False device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") # ==== Load embeddings & CSV ==== emb = torch.load(EMB_FILE, map_location=device) x_embeddings = emb["source"].to(device) # [N,D] y_embeddings = emb["target"].to(device) # [N,D] N, D = x_embeddings.shape print(f"Loaded embeddings: N={N}, D={D}") df = pd.read_csv(CSV_FILE) assert "target" in df.columns targets = df["target"].fillna("").tolist() # ==== Mapper: x_emb -> y_emb ==== class SemanticMapper(nn.Module): def __init__(self, dim): super().__init__() self.net = nn.Sequential( nn.Linear(dim, dim*2), nn.ReLU(), nn.Linear(dim*2, dim) ) def forward(self, x): return self.net(x) mapper = SemanticMapper(D).to(device) opt_map = optim.Adam(mapper.parameters(), lr=LR_MAPPER) crit_map = nn.CosineEmbeddingLoss() print("\nTraining mapper...") map_losses = [] for ep in range(EPOCHS_MAPPER): perm = torch.randperm(N, device=device) total = 0.0; steps = 0 for i in range(0, N, BATCH_SIZE_MAP): idx = perm[i:i+BATCH_SIZE_MAP] xb, yb = x_embeddings[idx], y_embeddings[idx] tgt = torch.ones(xb.size(0), device=device) pred = mapper(xb) loss = crit_map(pred, yb, tgt) opt_map.zero_grad(); loss.backward() opt_map.step() total += loss.item(); steps += 1 avg = total / max(1, steps) map_losses.append(avg) print(f"Mapper Epoch {ep+1}/{EPOCHS_MAPPER} - Loss: {avg:.6f}") if PLOT_LOSS: plt.figure(); plt.plot(map_losses, marker="o"); plt.title("Mapper Loss"); plt.grid(True); plt.show() torch.save({"state_dict": mapper.state_dict(), "dim": D}, "semantic_mapper.pth") print("Saved mapper -> semantic_mapper.pth") # ==== Decoder: y_emb -> target text ==== tokenizer = T5Tokenizer.from_pretrained("t5-small") tok = tokenizer(targets, padding=True, truncation=True, max_length=MAX_LEN, return_tensors="pt", add_special_tokens=True) labels = tok["input_ids"].to(device) # [N,L] pad_id = tokenizer.pad_token_id eos_id = tokenizer.eos_token_id # T5 uses as EOS # Build shifted inputs for strict teacher forcing: # y_in[0] = BOS (use pad_id for T5), then y_in[t] = labels[t-1] y_in = torch.full_like(labels, pad_id) y_in[:, 1:] = labels[:, :-1] y_out = labels # predict labels[t] given y_in[t] class EmbeddingDecoder(nn.Module): """ Strong conditioning: concat emb each step. Weight tying: embed.weight = fc.weight. Deterministic teacher forcing via pre-built y_in (no ratios). """ def __init__(self, input_dim, hidden_dim, vocab_size, p=0.2): super().__init__() self.bridge = nn.Linear(input_dim, hidden_dim) # emb -> h0 self.embed = nn.Embedding(vocab_size, hidden_dim) # token -> hidden self.gru = nn.GRU(hidden_dim + input_dim, hidden_dim, batch_first=True) self.ln = nn.LayerNorm(hidden_dim) self.fc = nn.Linear(hidden_dim, vocab_size, bias=True) self.drop = nn.Dropout(p) # Tie weights self.fc.weight = self.embed.weight def forward_teacher_forced(self, emb_vec, in_ids, max_len): """ emb_vec: [B,D], in_ids: [B,L] (strict teacher forcing inputs) Returns logits: [B,L,V] """ B, D_in = emb_vec.shape H0 = torch.tanh(self.bridge(emb_vec)).unsqueeze(0) # [1,B,H] logits_all = [] h = H0 for t in range(max_len): inp = in_ids[:, t].unsqueeze(1) # [B,1] token_h = self.drop(self.embed(inp)) # [B,1,H] step_in = torch.cat([token_h, emb_vec.unsqueeze(1)], dim=-1) # [B,1,H+D] out, h = self.gru(step_in, h) # [B,1,H] out = self.ln(out.squeeze(1)) # [B,H] logits = self.fc(self.drop(out)) # [B,V] logits_all.append(logits.unsqueeze(1)) return torch.cat(logits_all, dim=1) # [B,L,V] @torch.no_grad() def greedy_decode(self, emb_vec, max_len, start_id, eos_id): """ Pure greedy with EOS stop; forbids PAD to reduce loops. """ B, _ = emb_vec.shape h = torch.tanh(self.bridge(emb_vec)).unsqueeze(0) inp = torch.full((B,1), start_id, dtype=torch.long, device=emb_vec.device) out_ids = [] done = torch.zeros(B, dtype=torch.bool, device=emb_vec.device) for _ in range(max_len): token_h = self.embed(inp) # [B,1,H] step_in = torch.cat([token_h, emb_vec.unsqueeze(1)], dim=-1) out, h = self.gru(step_in, h) logits = self.fc(out.squeeze(1)) # [B,V] logits[:, pad_id] = -1e9 # discourage PAD next_id = torch.argmax(logits, dim=-1) # [B] out_ids.append(next_id.unsqueeze(1)) done |= (next_id == eos_id) if done.all(): break inp = next_id.unsqueeze(1) return torch.cat(out_ids, dim=1) # [B,T] decoder = EmbeddingDecoder(D, HIDDEN_DIM, tokenizer.vocab_size).to(device) opt_dec = optim.Adam(decoder.parameters(), lr=LR_DECODER) crit_dec = nn.CrossEntropyLoss(ignore_index=pad_id) # no smoothing (small N) print("\nTraining decoder...") dec_losses = [] steps = (N + BATCH_SIZE_DEC - 1) // BATCH_SIZE_DEC for ep in range(EPOCHS_DECODER): perm = torch.randperm(N, device=device) total = 0.0 for i in range(0, N, BATCH_SIZE_DEC): idx = perm[i:i+BATCH_SIZE_DEC] eb = y_embeddings[idx] # condition on TRUE target-space embeddings yin = y_in[idx] # shifted inputs yout = y_out[idx] # labels opt_dec.zero_grad() logits = decoder.forward_teacher_forced(eb, yin, max_len=yout.size(1)) # [B,L,V] loss = crit_dec(logits.reshape(-1, logits.size(-1)), yout.reshape(-1)) loss.backward() nn.utils.clip_grad_norm_(decoder.parameters(), 1.0) opt_dec.step() total += loss.item() avg = total / max(1, steps) dec_losses.append(avg) print(f"Decoder Epoch {ep+1}/{EPOCHS_DECODER} - Loss: {avg:.4f}") if PLOT_LOSS: plt.figure(); plt.plot(dec_losses, marker="o"); plt.title("Decoder Loss"); plt.grid(True); plt.show() torch.save({"state_dict": decoder.state_dict(), "dim": D, "vocab_size": tokenizer.vocab_size}, "embedding_decoder.pth") print("Saved decoder -> embedding_decoder.pth") # ==== E2E inference ==== embedder = SentenceTransformer(MODEL_NAME, device=device) try: dim = embedder.get_sentence_embedding_dimension() if dim != D: raise RuntimeError(f"Embedder dim {dim} != training dim {D}. Regenerate embeddings with same MODEL_NAME.") except Exception: pass @torch.no_grad() def generate(text: str, max_len: int = 24) -> str: # source -> x_emb x = embedder.encode([text], convert_to_tensor=True, device=device) # [1,D] # map -> y_emb y_pred = mapper(x) # [1,D] # decode y_emb -> text ids = decoder.greedy_decode(y_pred, max_len=max_len, start_id=pad_id, eos_id=eos_id)[0].tolist() return tokenizer.decode(ids, skip_special_tokens=True) print("\nE2E test:") inp = "User: Hi" print(f"{inp} ->", generate(inp))