|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.optim as optim
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
|
# Load the pre-computed embedding pairs produced by an earlier pipeline step.
# map_location="cpu" lets the checkpoint load even if it was saved from a GPU
# run onto a CPU-only machine; weights_only=True restricts torch.load to
# tensors/containers instead of unpickling arbitrary objects (safer default).
emb_data = torch.load("chat_embeddings.pt", map_location="cpu", weights_only=True)

# assumes emb_data is a dict of 2-D tensors shaped (num_samples, embedding_dim)
# with aligned rows — TODO confirm against the script that wrote the file
x_embeddings = emb_data["source"]
y_embeddings = emb_data["target"]

print("Source shape:", x_embeddings.shape)
print("Target shape:", y_embeddings.shape)

embedding_dim = x_embeddings.shape[1]
num_samples = x_embeddings.shape[0]

# Prefer GPU when available; data and (later) the model share this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

x_embeddings = x_embeddings.to(device)
y_embeddings = y_embeddings.to(device)
|
|
|
|
|
|
|
|
|
class SemanticMapper(nn.Module):
    """Two-layer MLP mapping one embedding space onto another.

    Expands ``dim`` to ``2 * dim``, applies a ReLU, then projects back
    to ``dim`` — input and output shapes match.
    """

    def __init__(self, dim):
        super().__init__()
        hidden = dim * 2
        layers = [
            nn.Linear(dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, dim),
        ]
        # Same module order as a hand-written Sequential, so parameter
        # initialization order and state_dict keys (net.0 / net.2) are
        # unchanged.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of embeddings; output shape equals input shape."""
        return self.net(x)
|
|
|
|
|
|
# Training configuration.
epochs = 20        # full passes over the dataset
batch_size = 64    # samples per optimizer step

# Model, optimizer and loss; the model joins the data on `device`.
model = SemanticMapper(embedding_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CosineEmbeddingLoss()

loss_history = []  # one average-loss entry per epoch, for plotting
|
|
|
|
|
|
|
|
|
# Mini-batch training: pull model(x) toward y under a cosine objective.
model.train()  # explicit, in case a reused model was left in eval mode

for epoch in range(epochs):
    # Fresh shuffle each epoch so batch composition varies.
    perm = torch.randperm(num_samples, device=device)
    epoch_loss = 0.0

    for i in range(0, num_samples, batch_size):
        idx = perm[i:i + batch_size]
        x_batch = x_embeddings[idx]
        y_batch = y_embeddings[idx]
        # Target of +1 tells CosineEmbeddingLoss these are positive pairs
        # (maximize cosine similarity between prediction and y_batch).
        target = torch.ones(x_batch.size(0), device=device)

        y_pred = model(x_batch)
        loss = criterion(y_pred, y_batch, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by the actual batch size so the final (possibly short)
        # batch contributes proportionally to the epoch average.
        epoch_loss += loss.item() * x_batch.size(0)

    # BUG FIX: the original divided by `num_samples / batch_size`, which is
    # not the batch count when num_samples is not a multiple of batch_size
    # (e.g. 100 samples / batch 64 gave denominator 1.5625 instead of 2).
    # With size-weighted accumulation above, dividing by num_samples yields
    # the true per-sample mean loss.
    avg_loss = epoch_loss / num_samples
    loss_history.append(avg_loss)
    print(f"Epoch {epoch + 1}/{epochs} - Loss: {avg_loss:.6f}")
|
|
|
|
|
|
|
|
|
# Visualize the optimization curve using the object-oriented API.
fig, ax = plt.subplots()
ax.plot(loss_history, marker="o")
ax.set_title("Training Loss (Cosine Similarity)")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.grid(True)
plt.show()

# Persist weights only; restore with SemanticMapper(dim).load_state_dict(...).
torch.save(model.state_dict(), "semantic_mapper.pth")