File size: 3,960 Bytes
7cd7caf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from transformers import T5Tokenizer
from sentence_transformers import SentenceTransformer

# ===== CONFIG =====
INPUT_FILE = "chat_1turn.csv"
EMB_FILE = "chat_embeddings.pt"
MODEL_NAME = "Snowflake/snowflake-arctic-embed-l-v2.0"
EPOCHS = 80
BATCH_SIZE = 16
HIDDEN_DIM = 512
MAX_LEN = 64  # max target length in tokens (truncation limit for the tokenizer)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ===== Load CSV =====
# Missing cells become "" so the tokenizer never receives NaN floats.
df = pd.read_csv(INPUT_FILE)
sources = df["source"].fillna("").tolist()
targets = df["target"].fillna("").tolist()

# ===== Tokenizer =====
# Targets are padded to the longest sequence (capped at MAX_LEN); the pad id is
# later ignored by the loss.
tokenizer = T5Tokenizer.from_pretrained("t5-small")
target_enc = tokenizer(targets, padding=True, truncation=True,
                       return_tensors="pt", max_length=MAX_LEN)
input_ids = target_enc["input_ids"].to(device)
attention_mask = target_enc["attention_mask"].to(device)

# ===== Load embeddings =====
# map_location lets embeddings that were saved from a GPU session load on a
# CPU-only machine (and vice versa) instead of failing inside torch.load.
emb_data = torch.load(EMB_FILE, map_location=device)
x_embeddings = emb_data["source"].to(device)  # not used directly in this training
y_embeddings = emb_data["target"].to(device)  # used to condition decoder

# Training pairs embedding i with CSV row i, so the counts must agree; otherwise
# rows would be silently dropped or batches would fail with a shape error.
if len(y_embeddings) != len(input_ids):
    raise ValueError(
        f"{EMB_FILE} has {len(y_embeddings)} target embeddings but "
        f"{INPUT_FILE} has {len(input_ids)} rows; they must match."
    )

# ===== Decoder =====
class EmbeddingDecoder(nn.Module):
    """GRU decoder that produces per-step token logits from one fixed-size embedding.

    The conditioning embedding is linearly projected (``bridge``) into the GRU's
    initial hidden state; tokens are then generated one step at a time, starting
    from ``start_token_id``.

    Args:
        input_dim: Size of the conditioning embedding's last dimension.
        hidden_dim: GRU hidden size, also used as the token-embedding size.
        vocab_size: Number of logits produced per step.
        start_token_id: Token fed to the GRU at step 0. ``None`` (the default)
            uses the module-level ``tokenizer.pad_token_id`` at call time, which
            is the original behavior (T5 starts decoding from its pad token).
    """

    def __init__(self, input_dim, hidden_dim, vocab_size, start_token_id=None):
        super().__init__()
        self.bridge = nn.Linear(input_dim, hidden_dim)
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.start_token_id = start_token_id

    def forward(self, emb_vec, target_ids=None, teacher_forcing_ratio=0.5, max_len=None):
        """Decode ``max_len`` steps conditioned on ``emb_vec``.

        Args:
            emb_vec: 2-D tensor ``[B, input_dim]`` (``size(0)`` is the batch).
            target_ids: Optional ``[B, T]`` token ids for teacher forcing. After
                step ``t`` produces its logits, ``target_ids[:, t]`` is fed as the
                next input with probability ``teacher_forcing_ratio``; otherwise
                (or once ``t >= T``) the step's own argmax is fed back.
            teacher_forcing_ratio: Probability of feeding the ground-truth token.
            max_len: Number of decoding steps; ``None`` means the module-level
                ``MAX_LEN`` (resolved at call time).

        Returns:
            Logits of shape ``[B, max_len, vocab_size]``.
        """
        if max_len is None:
            max_len = MAX_LEN
        start_id = self.start_token_id
        if start_id is None:
            start_id = tokenizer.pad_token_id

        # nn.Linear requires matching dtypes; embeddings may be stored in e.g.
        # float16/float64, so cast to the bridge's parameter dtype.
        emb_vec = emb_vec.to(self.bridge.weight.dtype)
        hidden = self.bridge(emb_vec).unsqueeze(0)  # [1, B, H]: GRU initial state
        batch = emb_vec.size(0)

        inp = torch.full((batch, 1), start_id, dtype=torch.long, device=emb_vec.device)
        outputs = []
        for t in range(max_len):
            inp_emb = self.embed(inp)                # [B, 1, H]
            out, hidden = self.gru(inp_emb, hidden)  # [B, 1, H]
            logits = self.fc(out.squeeze(1))        # [B, V]
            outputs.append(logits.unsqueeze(1))

            # The random draw happens only when a target token exists for this step.
            use_target = (
                target_ids is not None
                and t < target_ids.size(1)
                and torch.rand(1).item() < teacher_forcing_ratio
            )
            if use_target:
                inp = target_ids[:, t].unsqueeze(1)
            else:
                inp = torch.argmax(logits, dim=-1, keepdim=True)

        if not outputs:
            # max_len == 0: return an empty sequence instead of failing in torch.cat.
            return emb_vec.new_zeros((batch, 0, self.fc.out_features))
        return torch.cat(outputs, dim=1)  # [B, max_len, V]

# ===== Train =====
decoder = EmbeddingDecoder(y_embeddings.shape[1], HIDDEN_DIM, tokenizer.vocab_size).to(device)
optimizer = optim.Adam(decoder.parameters(), lr=1e-3)
# Padded positions (pad id) in the targets are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

print("Training decoder...")
num_samples = len(y_embeddings)
for epoch in range(EPOCHS):
    decoder.train()
    # Reshuffle every epoch so batch composition varies. The same permutation
    # indexes both tensors, keeping embedding i paired with target i.
    order = torch.randperm(num_samples, device=device)
    total_loss = 0.0
    num_batches = 0
    for start in range(0, num_samples, BATCH_SIZE):
        batch_idx = order[start:start + BATCH_SIZE]
        xb = y_embeddings[batch_idx]
        yb = input_ids[batch_idx]

        optimizer.zero_grad()
        logits = decoder(xb, target_ids=yb, teacher_forcing_ratio=0.7, max_len=yb.size(1))
        loss = criterion(logits.reshape(-1, logits.size(-1)), yb.reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    # Report the mean per-batch loss so the value is comparable across dataset
    # sizes (a raw sum grows with the number of batches).
    mean_loss = total_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {mean_loss:.4f}")

# Inference mode for generation; a no-op for the current layers, but it keeps
# generation correct if dropout/normalization layers are added later.
decoder.eval()

# ===== Inference =====
embedder = SentenceTransformer(MODEL_NAME, device=device)

def generate(text, max_len=30, use_mapper=False, mapper=None):
    """Greedily decode text from the sentence embedding of *text*.

    The decoder was trained to map a target's embedding back to that target's
    tokens, so without a mapper this roughly reconstructs *text* itself.

    Args:
        text: Input string, embedded with the module-level ``embedder``.
        max_len: Number of decoding steps (upper bound on generated tokens).
        use_mapper: If True and *mapper* is not None, apply *mapper* to the
            embedding before decoding. If *mapper* is None the flag is ignored.
        mapper: Callable taking and returning an embedding tensor (presumably
            a source->target embedding map; confirm with its definition).

    Returns:
        The decoded string, cut at the first EOS token, with special tokens removed.
    """
    with torch.no_grad():
        emb = embedder.encode([text], convert_to_tensor=True, device=device)
        if use_mapper and mapper is not None:
            emb = mapper(emb)
        logits = decoder(emb, target_ids=None, teacher_forcing_ratio=0.0, max_len=max_len)
        ids = torch.argmax(logits, dim=-1).squeeze(0).tolist()

    # The decoder always runs all max_len steps, and tokens after the first EOS
    # are not part of the output. skip_special_tokens would drop the EOS itself
    # but keep everything generated after it, so truncate explicitly.
    eos_id = tokenizer.eos_token_id
    if eos_id is not None and eos_id in ids:
        ids = ids[:ids.index(eos_id)]
    return tokenizer.decode(ids, skip_special_tokens=True)

# ===== Test =====
# Smoke test: decode the embedding of a short greeting and show the result.
reply = generate("Hi")
print("Hi ->", reply)