import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from transformers import T5Tokenizer
from sentence_transformers import SentenceTransformer

# ===== CONFIG =====
INPUT_FILE = "chat_1turn.csv"
EMB_FILE = "chat_embeddings.pt"
MODEL_NAME = "Snowflake/snowflake-arctic-embed-l-v2.0"
EPOCHS = 80
BATCH_SIZE = 16
HIDDEN_DIM = 512
MAX_LEN = 64

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ===== Load CSV =====
df = pd.read_csv(INPUT_FILE)
sources = df["source"].fillna("").tolist()
targets = df["target"].fillna("").tolist()

# ===== Tokenizer =====
tokenizer = T5Tokenizer.from_pretrained("t5-small")
target_enc = tokenizer(targets, padding=True, truncation=True,
                       return_tensors="pt", max_length=MAX_LEN)
input_ids = target_enc["input_ids"].to(device)
attention_mask = target_enc["attention_mask"].to(device)

# ===== Load embeddings =====
emb_data = torch.load(EMB_FILE, map_location=device)
x_embeddings = emb_data["source"].to(device)  # not used directly in this training
y_embeddings = emb_data["target"].to(device)  # used to condition the decoder

# ===== Decoder =====
class EmbeddingDecoder(nn.Module):
    """GRU decoder that reconstructs a token sequence from a single sentence embedding."""

    def __init__(self, input_dim, hidden_dim, vocab_size):
        super().__init__()
        self.bridge = nn.Linear(input_dim, hidden_dim)  # embedding -> initial hidden state
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, emb_vec, target_ids=None, teacher_forcing_ratio=0.5, max_len=MAX_LEN):
        hidden = self.bridge(emb_vec).unsqueeze(0)  # [1, B, H]
        B = emb_vec.size(0)
        outputs = []
        # Start with the pad token (T5 convention: pad=0, eos=1; pad doubles as the decoder start token).
        inp = torch.full((B, 1), tokenizer.pad_token_id, device=emb_vec.device)
        for t in range(max_len):
            inp_emb = self.embed(inp)                 # [B, 1, H]
            out, hidden = self.gru(inp_emb, hidden)   # [B, 1, H]
            logits = self.fc(out.squeeze(1))          # [B, V]
            outputs.append(logits.unsqueeze(1))
            if (target_ids is not None and t < target_ids.size(1)
                    and torch.rand(1).item() < teacher_forcing_ratio):
                inp = target_ids[:, t].unsqueeze(1)   # teacher forcing: feed the gold token
            else:
                inp = torch.argmax(logits, dim=-1, keepdim=True)  # feed the model's own prediction
        return torch.cat(outputs, dim=1)  # [B, max_len, V]

# ===== Train =====
decoder = EmbeddingDecoder(y_embeddings.shape[1], HIDDEN_DIM, tokenizer.vocab_size).to(device)
optimizer = optim.Adam(decoder.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

print("Training decoder...")
for epoch in range(EPOCHS):
    decoder.train()
    total_loss = 0.0
    for i in range(0, len(y_embeddings), BATCH_SIZE):
        xb = y_embeddings[i:i + BATCH_SIZE]
        yb = input_ids[i:i + BATCH_SIZE]
        optimizer.zero_grad()
        logits = decoder(xb, target_ids=yb, teacher_forcing_ratio=0.7, max_len=yb.size(1))
        loss = criterion(logits.reshape(-1, logits.size(-1)), yb.reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch + 1}/{EPOCHS} - Loss: {total_loss:.4f}")

# ===== Inference =====
embedder = SentenceTransformer(MODEL_NAME, device=device)

def generate(text, max_len=30, use_mapper=False, mapper=None):
    decoder.eval()
    with torch.no_grad():
        # Embed the new text with the sentence encoder.
        emb = embedder.encode([text], convert_to_tensor=True, device=device)
        if use_mapper and mapper is not None:
            emb = mapper(emb)  # optional source->target embedding mapper
        logits = decoder(emb, target_ids=None, teacher_forcing_ratio=0.0, max_len=max_len)
        ids = torch.argmax(logits, dim=-1).squeeze(0).tolist()
        # Truncate at the first EOS so tokens generated after it are ignored.
        if tokenizer.eos_token_id in ids:
            ids = ids[:ids.index(tokenizer.eos_token_id)]
        return tokenizer.decode(ids, skip_special_tokens=True)

# ===== Test =====
print("Hi ->", generate("Hi"))
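
# ===== Optional: precompute chat_embeddings.pt =====
# The training loop above expects EMB_FILE to already contain a dict with "source" and
# "target" tensors. This is a minimal sketch (under that assumption) of how the file
# could be produced with the same sentence encoder; the function name and the commented
# call below are illustrative, not part of the original pipeline.
def build_embedding_file(csv_path=INPUT_FILE, out_path=EMB_FILE):
    enc = SentenceTransformer(MODEL_NAME, device=device)
    frame = pd.read_csv(csv_path)
    src = frame["source"].fillna("").tolist()
    tgt = frame["target"].fillna("").tolist()
    # Encode both columns and store them on CPU so the file loads on any device.
    src_emb = enc.encode(src, convert_to_tensor=True, device=device, batch_size=32)
    tgt_emb = enc.encode(tgt, convert_to_tensor=True, device=device, batch_size=32)
    torch.save({"source": src_emb.cpu(), "target": tgt_emb.cpu()}, out_path)

# Example (run once before training):
# build_embedding_file()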