|
|
|
|
|
import torch, torch.nn as nn, torch.optim as optim
|
|
|
import pandas as pd
|
|
|
import matplotlib.pyplot as plt
|
|
|
from transformers import T5Tokenizer
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
|
|
|
# --- Input artifacts --------------------------------------------------------
# Tensor file loaded with torch.load; read below via its "source"/"target" keys.
EMB_FILE: str = "chat_embeddings.pt"
# CSV whose "target" column supplies the decoder's training text.
CSV_FILE: str = "chat_1turn.csv"
# SentenceTransformer checkpoint used to embed prompts at inference time.
MODEL_NAME: str = "Snowflake/snowflake-arctic-embed-l-v2.0"

# --- Training schedule ------------------------------------------------------
EPOCHS_MAPPER: int = 20
EPOCHS_DECODER: int = 160
BATCH_SIZE_MAP: int = 64
BATCH_SIZE_DEC: int = 64
LR_MAPPER: float = 1e-3
LR_DECODER: float = 1e-3

# --- Model / tokenization ---------------------------------------------------
# Hidden size handed to EmbeddingDecoder.
HIDDEN_DIM: int = 512
# Truncation length (in tokens) applied when tokenizing the targets.
MAX_LEN: int = 64

# --- Diagnostics ------------------------------------------------------------
# When True, loss curves are shown with matplotlib after each training phase.
PLOT_LOSS: bool = False
|
|
|
|
|
|
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
|
|
|
|
|
|
|
|
|
# Load the precomputed embedding pairs. The file is expected to hold a mapping
# with "source" and "target" 2-D tensors.
# NOTE(review): torch.load unpickles the file; only point EMB_FILE at a
# trusted artifact.
emb = torch.load(EMB_FILE, map_location=device)
x_embeddings = emb["source"].to(device)
y_embeddings = emb["target"].to(device)

# The mapper maps D -> D and both tensors are indexed with the same
# permutation during training, so their shapes must agree exactly.
if x_embeddings.shape != y_embeddings.shape:
    raise ValueError(
        f"Source/target embedding shapes differ: "
        f"{tuple(x_embeddings.shape)} vs {tuple(y_embeddings.shape)}"
    )
N, D = x_embeddings.shape
print(f"Loaded embeddings: N={N}, D={D}")

# Load the target texts that the decoder learns to reproduce.
df = pd.read_csv(CSV_FILE)
# Explicit raise instead of `assert`, which is stripped under `python -O`.
if "target" not in df.columns:
    raise KeyError(f"Column 'target' not found in {CSV_FILE}")
targets = df["target"].fillna("").tolist()

# Row i of the CSV is paired with embedding row i when training the decoder;
# a length mismatch would misalign text and embeddings (or index out of range).
if len(targets) != N:
    raise ValueError(
        f"{CSV_FILE} has {len(targets)} rows but {EMB_FILE} has {N} embeddings; "
        f"regenerate the embeddings from the same CSV."
    )
|
|
|
|
|
|
|
|
|
class SemanticMapper(nn.Module):
    """Two-layer MLP mapping a source embedding to a predicted target embedding.

    The network widens to ``2 * dim``, applies a ReLU, and projects back to
    ``dim``, so input and output share the same last dimension.
    """

    def __init__(self, dim):
        """Build the MLP for embeddings of size ``dim``."""
        super().__init__()
        widened = dim * 2
        # Positional layers keep the state_dict keys as net.0.* / net.2.*.
        layers = [
            nn.Linear(dim, widened),
            nn.ReLU(),
            nn.Linear(widened, dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the mapped embedding for ``x``."""
        mapped = self.net(x)
        return mapped
|
|
|
|
|
|
# Mapper, its optimizer, and a cosine objective that pulls the prediction
# toward the target embedding.
mapper = SemanticMapper(D).to(device)
opt_map = optim.Adam(mapper.parameters(), lr=LR_MAPPER)
crit_map = nn.CosineEmbeddingLoss()

print("\nTraining mapper...")
map_losses = []
for ep in range(EPOCHS_MAPPER):
    # Fresh shuffle of the sample order every epoch.
    perm = torch.randperm(N, device=device)
    total = 0.0
    steps = 0
    for start in range(0, N, BATCH_SIZE_MAP):
        batch_idx = perm[start:start + BATCH_SIZE_MAP]
        src_batch = x_embeddings[batch_idx]
        tgt_batch = y_embeddings[batch_idx]
        # Label +1 for every pair: all pairs should be similar.
        similar = torch.ones(src_batch.size(0), device=device)

        opt_map.zero_grad()
        loss = crit_map(mapper(src_batch), tgt_batch, similar)
        loss.backward()
        opt_map.step()

        total += loss.item()
        steps += 1

    # Mean batch loss for the epoch (guarded against zero batches).
    avg = total / max(1, steps)
    map_losses.append(avg)
    print(f"Mapper Epoch {ep+1}/{EPOCHS_MAPPER} - Loss: {avg:.6f}")
|
|
|
|
|
|
# Optionally show the per-epoch mapper loss curve.
if PLOT_LOSS:
    plt.figure()
    plt.plot(map_losses, marker="o")
    plt.title("Mapper Loss")
    plt.grid(True)
    plt.show()

# Persist the mapper weights together with the embedding size they expect.
mapper_ckpt = {"state_dict": mapper.state_dict(), "dim": D}
torch.save(mapper_ckpt, "semantic_mapper.pth")
print("Saved mapper -> semantic_mapper.pth")
|
|
|
|
|
|
|
|
|
# Tokenize every target text into one padded [N, L] id matrix.
tokenizer = T5Tokenizer.from_pretrained("t5-small")
tok = tokenizer(
    targets,
    padding=True,
    truncation=True,
    max_length=MAX_LEN,
    return_tensors="pt",
    add_special_tokens=True,
)
labels = tok["input_ids"].to(device)
pad_id = tokenizer.pad_token_id
eos_id = tokenizer.eos_token_id

# Teacher-forcing tensors: the decoder input is the label sequence shifted
# right by one position with PAD acting as the start token; the decoder
# output target is the unshifted label sequence.
start_col = torch.full_like(labels[:, :1], pad_id)
y_in = torch.cat([start_col, labels[:, :-1]], dim=1)
y_out = labels
|
|
|
|
|
|
class EmbeddingDecoder(nn.Module):
    """GRU decoder that turns one embedding vector into a token-id sequence.

    Strong conditioning: the embedding initialises the GRU hidden state (via
    ``bridge``) and is also concatenated to the token embedding at every step.
    Weight tying: the output projection ``fc`` shares its weight matrix with
    the token embedding table ``embed``.
    Training uses deterministic teacher forcing with pre-built input ids.
    """

    def __init__(self, input_dim, hidden_dim, vocab_size, p=0.2):
        """Create the decoder.

        Args:
            input_dim: size of the conditioning embedding.
            hidden_dim: GRU hidden size; also the token-embedding size.
            vocab_size: number of distinct token ids.
            p: dropout probability used on the teacher-forced path.
        """
        super().__init__()
        self.bridge = nn.Linear(input_dim, hidden_dim)
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.gru = nn.GRU(hidden_dim + input_dim, hidden_dim, batch_first=True)
        self.ln = nn.LayerNorm(hidden_dim)
        self.fc = nn.Linear(hidden_dim, vocab_size, bias=True)
        self.drop = nn.Dropout(p)

        # Weight tying: both matrices are [vocab_size, hidden_dim].
        self.fc.weight = self.embed.weight

    def _initial_state(self, emb_vec):
        """Project the conditioning embedding to the GRU's initial hidden state [1,B,H]."""
        return torch.tanh(self.bridge(emb_vec)).unsqueeze(0)

    def forward_teacher_forced(self, emb_vec, in_ids, max_len):
        """Compute logits for every position under strict teacher forcing.

        Args:
            emb_vec: [B,D] conditioning embeddings.
            in_ids: [B,L] decoder input ids (targets shifted right).
            max_len: number of leading positions of ``in_ids`` to decode.

        Returns:
            Logits of shape [B,max_len,V].
        """
        in_ids = in_ids[:, :max_len]
        seq_len = in_ids.size(1)
        # All inputs are known up front, so the GRU can consume the whole
        # sequence in one call instead of a Python loop over time steps.
        token_h = self.drop(self.embed(in_ids))
        cond = emb_vec.unsqueeze(1).expand(-1, seq_len, -1)
        out, _ = self.gru(torch.cat([token_h, cond], dim=-1), self._initial_state(emb_vec))
        return self.fc(self.drop(self.ln(out)))

    @torch.no_grad()
    def greedy_decode(self, emb_vec, max_len, start_id, eos_id, pad_id=None):
        """Greedy decoding with EOS stop.

        Args:
            emb_vec: [B,D] conditioning embeddings.
            max_len: maximum number of tokens to generate.
            start_id: id fed as the first decoder input.
            eos_id: id that marks a finished sequence.
            pad_id: id that may never be generated (reduces loops). Defaults
                to ``start_id`` when omitted.

        Returns:
            Long tensor [B,T] with T <= max_len. Decoding stops once every row
            has produced ``eos_id``; rows that finish early are filled with
            ``eos_id`` for the remaining steps.
        """
        banned_id = start_id if pad_id is None else pad_id
        batch = emb_vec.size(0)
        dev = emb_vec.device

        h = self._initial_state(emb_vec)
        cond = emb_vec.unsqueeze(1)
        inp = torch.full((batch, 1), start_id, dtype=torch.long, device=dev)
        done = torch.zeros(batch, dtype=torch.bool, device=dev)
        out_ids = []

        for _ in range(max_len):
            step_in = torch.cat([self.embed(inp), cond], dim=-1)
            out, h = self.gru(step_in, h)
            # Apply the same LayerNorm as the teacher-forced path so the
            # projection sees inputs from the distribution it was trained on.
            logits = self.fc(self.ln(out.squeeze(1)))
            logits[:, banned_id] = float("-inf")
            next_id = torch.argmax(logits, dim=-1)
            # Rows that already emitted EOS stay at EOS instead of generating
            # trailing tokens while other rows are still decoding.
            next_id = torch.where(done, torch.full_like(next_id, eos_id), next_id)
            out_ids.append(next_id.unsqueeze(1))
            done |= next_id == eos_id
            if done.all():
                break
            inp = next_id.unsqueeze(1)

        if not out_ids:
            # max_len <= 0: nothing generated.
            return torch.empty((batch, 0), dtype=torch.long, device=dev)
        return torch.cat(out_ids, dim=1)
|
|
|
|
|
|
# Decoder, its optimizer, and a token-level cross-entropy that skips PAD
# positions in the targets.
decoder = EmbeddingDecoder(D, HIDDEN_DIM, tokenizer.vocab_size).to(device)
opt_dec = optim.Adam(decoder.parameters(), lr=LR_DECODER)
crit_dec = nn.CrossEntropyLoss(ignore_index=pad_id)

print("\nTraining decoder...")
dec_losses = []
# Number of mini-batches per epoch (ceil of N / BATCH_SIZE_DEC).
steps = len(range(0, N, BATCH_SIZE_DEC))
for ep in range(EPOCHS_DECODER):
    perm = torch.randperm(N, device=device)
    total = 0.0
    for start in range(0, N, BATCH_SIZE_DEC):
        batch_idx = perm[start:start + BATCH_SIZE_DEC]
        # The decoder is conditioned on the true target embeddings.
        cond_emb = y_embeddings[batch_idx]
        dec_inputs = y_in[batch_idx]
        dec_targets = y_out[batch_idx]

        opt_dec.zero_grad()
        logits = decoder.forward_teacher_forced(
            cond_emb, dec_inputs, max_len=dec_targets.size(1)
        )
        # Flatten [B,L,V] -> [B*L,V] and [B,L] -> [B*L] for the loss.
        vocab = logits.size(-1)
        loss = crit_dec(logits.reshape(-1, vocab), dec_targets.reshape(-1))
        loss.backward()
        # Clip the global gradient norm to 1.0 before the update.
        nn.utils.clip_grad_norm_(decoder.parameters(), 1.0)
        opt_dec.step()

        total += loss.item()

    avg = total / max(1, steps)
    dec_losses.append(avg)
    print(f"Decoder Epoch {ep+1}/{EPOCHS_DECODER} - Loss: {avg:.4f}")
|
|
|
|
|
|
# Optionally show the per-epoch decoder loss curve.
if PLOT_LOSS:
    plt.figure()
    plt.plot(dec_losses, marker="o")
    plt.title("Decoder Loss")
    plt.grid(True)
    plt.show()

# Persist the decoder weights with the sizes needed to rebuild the module.
decoder_ckpt = {
    "state_dict": decoder.state_dict(),
    "dim": D,
    "vocab_size": tokenizer.vocab_size,
}
torch.save(decoder_ckpt, "embedding_decoder.pth")
print("Saved decoder -> embedding_decoder.pth")
|
|
|
|
|
|
|
|
|
# Sentence embedder used at inference time to embed the user's prompt.
embedder = SentenceTransformer(MODEL_NAME, device=device)

# Best-effort lookup of the embedder's output size. Only the lookup itself is
# guarded: previously the mismatch error below was raised inside the same
# try/except and silently swallowed, so the check could never fire.
try:
    dim = embedder.get_sentence_embedding_dimension()
except Exception as exc:
    print(f"Warning: could not query embedder dimension ({exc}); skipping dimension check.")
    dim = None

# The mapper was trained on D-dimensional embeddings, so a different embedder
# size cannot be fed into it.
if dim is not None and dim != D:
    raise RuntimeError(f"Embedder dim {dim} != training dim {D}. Regenerate embeddings with same MODEL_NAME.")
|
|
|
|
|
|
@torch.no_grad()
def generate(text: str, max_len: int = 24) -> str:
    """Produce a reply for ``text`` by embedding, mapping, then decoding.

    Args:
        text: input prompt.
        max_len: maximum number of tokens to decode.

    Returns:
        The decoded string with special tokens removed.
    """
    # 1) Embed the prompt as a batch of one.
    source_vec = embedder.encode([text], convert_to_tensor=True, device=device)

    # 2) Map the source embedding to a predicted target embedding.
    target_vec = mapper(source_vec)

    # 3) Greedily decode token ids; PAD doubles as the start token.
    decoded = decoder.greedy_decode(
        target_vec, max_len=max_len, start_id=pad_id, eos_id=eos_id
    )
    first_row = decoded[0].tolist()
    return tokenizer.decode(first_row, skip_special_tokens=True)
|
|
|
|
|
|
# Smoke-test the whole pipeline on a single prompt.
print("\nE2E test:")
inp = "User: Hi"
reply = generate(inp)
print(f"{inp} ->", reply)
|
|
|
|