#!/usr/bin/env python3
# V2L-Alpha1-Example1 / test_full.py
import torch, torch.nn as nn, torch.optim as optim
import pandas as pd
import matplotlib.pyplot as plt
from transformers import T5Tokenizer
from sentence_transformers import SentenceTransformer
# ==== Config ====
EMB_FILE = "chat_embeddings.pt" # {"source": [N,D], "target": [N,D]}
CSV_FILE = "chat_1turn.csv" # columns: source, target
MODEL_NAME = "Snowflake/snowflake-arctic-embed-l-v2.0"
EPOCHS_MAPPER = 20
EPOCHS_DECODER = 160
BATCH_SIZE_MAP = 64
BATCH_SIZE_DEC = 64
LR_MAPPER = 1e-3
LR_DECODER = 1e-3
HIDDEN_DIM = 512
MAX_LEN = 64
PLOT_LOSS = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# ==== Load embeddings & CSV ====
emb = torch.load(EMB_FILE, map_location=device)
x_embeddings = emb["source"].to(device) # [N,D]
y_embeddings = emb["target"].to(device) # [N,D]
N, D = x_embeddings.shape
print(f"Loaded embeddings: N={N}, D={D}")
df = pd.read_csv(CSV_FILE)
assert "target" in df.columns
targets = df["target"].fillna("").tolist()
# ==== Mapper: x_emb -> y_emb ====
class SemanticMapper(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * 2), nn.ReLU(),
            nn.Linear(dim * 2, dim),
        )

    def forward(self, x):
        return self.net(x)
mapper = SemanticMapper(D).to(device)
opt_map = optim.Adam(mapper.parameters(), lr=LR_MAPPER)
crit_map = nn.CosineEmbeddingLoss()
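# With target=1 for every pair, CosineEmbeddingLoss reduces to 1 - cos(pred, y),
# so the mapper is trained purely to maximize cosine similarity with y_emb.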
print("\nTraining mapper...")
map_losses = []
for ep in range(EPOCHS_MAPPER):
    perm = torch.randperm(N, device=device)
    total = 0.0
    steps = 0
    for i in range(0, N, BATCH_SIZE_MAP):
        idx = perm[i:i+BATCH_SIZE_MAP]
        xb, yb = x_embeddings[idx], y_embeddings[idx]
        tgt = torch.ones(xb.size(0), device=device)
        pred = mapper(xb)
        loss = crit_map(pred, yb, tgt)
        opt_map.zero_grad()
        loss.backward()
        opt_map.step()
        total += loss.item()
        steps += 1
    avg = total / max(1, steps)
    map_losses.append(avg)
    print(f"Mapper Epoch {ep+1}/{EPOCHS_MAPPER} - Loss: {avg:.6f}")
if PLOT_LOSS:
    plt.figure(); plt.plot(map_losses, marker="o"); plt.title("Mapper Loss"); plt.grid(True); plt.show()
torch.save({"state_dict": mapper.state_dict(), "dim": D}, "semantic_mapper.pth")
print("Saved mapper -> semantic_mapper.pth")
# ==== Decoder: y_emb -> target text ====
tokenizer = T5Tokenizer.from_pretrained("t5-small")
tok = tokenizer(targets, padding=True, truncation=True, max_length=MAX_LEN,
                return_tensors="pt", add_special_tokens=True)
labels = tok["input_ids"].to(device) # [N,L]
pad_id = tokenizer.pad_token_id
eos_id = tokenizer.eos_token_id # T5 uses </s> as EOS
# Build shifted inputs for strict teacher forcing:
# y_in[0] = BOS (use pad_id for T5), then y_in[t] = labels[t-1]
y_in = torch.full_like(labels, pad_id)
y_in[:, 1:] = labels[:, :-1]
y_out = labels # predict labels[t] given y_in[t]
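# Shift example for one row (pad doubles as BOS, matching T5's decoder convention):
#   labels: [  a,   b,   c, </s>, pad]
#   y_in:   [pad,   a,   b,    c, </s>]
# Position t therefore predicts labels[t]; padded positions are ignored by the loss.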
class EmbeddingDecoder(nn.Module):
    """
    Strong conditioning: concat emb each step.
    Weight tying: embed.weight = fc.weight.
    Deterministic teacher forcing via pre-built y_in (no ratios).
    """
    def __init__(self, input_dim, hidden_dim, vocab_size, p=0.2):
        super().__init__()
        self.bridge = nn.Linear(input_dim, hidden_dim)     # emb -> h0
        self.embed = nn.Embedding(vocab_size, hidden_dim)  # token -> hidden
        self.gru = nn.GRU(hidden_dim + input_dim, hidden_dim, batch_first=True)
        self.ln = nn.LayerNorm(hidden_dim)
        self.fc = nn.Linear(hidden_dim, vocab_size, bias=True)
        self.drop = nn.Dropout(p)
        # Tie weights
        self.fc.weight = self.embed.weight

    def forward_teacher_forced(self, emb_vec, in_ids, max_len):
        """
        emb_vec: [B,D], in_ids: [B,L] (strict teacher forcing inputs)
        Returns logits: [B,L,V]
        """
        B, D_in = emb_vec.shape
        h = torch.tanh(self.bridge(emb_vec)).unsqueeze(0)  # [1,B,H]
        logits_all = []
        for t in range(max_len):
            inp = in_ids[:, t].unsqueeze(1)                # [B,1]
            token_h = self.drop(self.embed(inp))           # [B,1,H]
            step_in = torch.cat([token_h, emb_vec.unsqueeze(1)], dim=-1)  # [B,1,H+D]
            out, h = self.gru(step_in, h)                  # [B,1,H]
            out = self.ln(out.squeeze(1))                  # [B,H]
            logits = self.fc(self.drop(out))               # [B,V]
            logits_all.append(logits.unsqueeze(1))
        # Note: with strict teacher forcing every input token is known up front, so
        # this loop could be a single gru() call over the whole sequence; the
        # per-step form is kept to mirror greedy_decode below.
        return torch.cat(logits_all, dim=1)                # [B,L,V]

    @torch.no_grad()
    def greedy_decode(self, emb_vec, max_len, start_id, eos_id):
        """
        Pure greedy with EOS stop; forbids PAD to reduce loops.
        """
        B, _ = emb_vec.shape
        h = torch.tanh(self.bridge(emb_vec)).unsqueeze(0)
        inp = torch.full((B, 1), start_id, dtype=torch.long, device=emb_vec.device)
        out_ids = []
        done = torch.zeros(B, dtype=torch.bool, device=emb_vec.device)
        for _ in range(max_len):
            token_h = self.embed(inp)                      # [B,1,H]
            step_in = torch.cat([token_h, emb_vec.unsqueeze(1)], dim=-1)
            out, h = self.gru(step_in, h)
            logits = self.fc(out.squeeze(1))               # [B,V]
            logits[:, pad_id] = -1e9                       # discourage PAD (uses module-level pad_id)
            next_id = torch.argmax(logits, dim=-1)         # [B]
            out_ids.append(next_id.unsqueeze(1))
            done |= (next_id == eos_id)
            if done.all():
                break
            inp = next_id.unsqueeze(1)
        return torch.cat(out_ids, dim=1)                   # [B,T]
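
# Quick shape sanity check (a sketch with dummy inputs; B=2 is arbitrary):
#   dec = EmbeddingDecoder(D, HIDDEN_DIM, tokenizer.vocab_size).to(device)
#   logits = dec.forward_teacher_forced(torch.randn(2, D, device=device), y_in[:2], max_len=y_in.size(1))
#   assert logits.shape == (2, y_in.size(1), tokenizer.vocab_size)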
decoder = EmbeddingDecoder(D, HIDDEN_DIM, tokenizer.vocab_size).to(device)
opt_dec = optim.Adam(decoder.parameters(), lr=LR_DECODER)
crit_dec = nn.CrossEntropyLoss(ignore_index=pad_id) # no smoothing (small N)
print("\nTraining decoder...")
dec_losses = []
steps = (N + BATCH_SIZE_DEC - 1) // BATCH_SIZE_DEC
for ep in range(EPOCHS_DECODER):
    perm = torch.randperm(N, device=device)
    total = 0.0
    for i in range(0, N, BATCH_SIZE_DEC):
        idx = perm[i:i+BATCH_SIZE_DEC]
        eb = y_embeddings[idx]   # condition on TRUE target-space embeddings
        yin = y_in[idx]          # shifted inputs
        yout = y_out[idx]        # labels
        opt_dec.zero_grad()
        logits = decoder.forward_teacher_forced(eb, yin, max_len=yout.size(1))  # [B,L,V]
        loss = crit_dec(logits.reshape(-1, logits.size(-1)), yout.reshape(-1))
        loss.backward()
        nn.utils.clip_grad_norm_(decoder.parameters(), 1.0)
        opt_dec.step()
        total += loss.item()
    avg = total / max(1, steps)
    dec_losses.append(avg)
    print(f"Decoder Epoch {ep+1}/{EPOCHS_DECODER} - Loss: {avg:.4f}")
if PLOT_LOSS:
    plt.figure(); plt.plot(dec_losses, marker="o"); plt.title("Decoder Loss"); plt.grid(True); plt.show()
torch.save({"state_dict": decoder.state_dict(), "dim": D, "vocab_size": tokenizer.vocab_size},
"embedding_decoder.pth")
print("Saved decoder -> embedding_decoder.pth")
# ==== E2E inference ====
embedder = SentenceTransformer(MODEL_NAME, device=device)
# Sanity check: the inference embedder must match the training embedding dim.
# (The original try/except swallowed its own RuntimeError, so the check never fired.)
try:
    emb_dim = embedder.get_sentence_embedding_dimension()
except Exception:
    emb_dim = None  # some models do not report a dimension
if emb_dim is not None and emb_dim != D:
    raise RuntimeError(f"Embedder dim {emb_dim} != training dim {D}. Regenerate embeddings with same MODEL_NAME.")
@torch.no_grad()
def generate(text: str, max_len: int = 24) -> str:
    # source -> x_emb
    x = embedder.encode([text], convert_to_tensor=True, device=device)  # [1,D]
    # map -> y_emb (NB: the decoder was trained on true target embeddings, so
    # output quality depends on how closely mapper(x) lands in that space)
    y_pred = mapper(x)  # [1,D]
    # decode y_emb -> text
    ids = decoder.greedy_decode(y_pred, max_len=max_len, start_id=pad_id, eos_id=eos_id)[0].tolist()
    return tokenizer.decode(ids, skip_special_tokens=True)
print("\nE2E test:")
inp = "User: Hi"
print(f"{inp} ->", generate(inp))