Upload 8 files
- decoder.py +103 -0
- formatter.py +77 -0
- main.py +79 -0
- pre_embed.py +55 -0
- test_chat.py +91 -0
- test_embed.py +18 -0
- test_full.py +203 -0
- testcuda.py +6 -0
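
These scripts form a small embedding-to-text pipeline: formatter.py turns LMSYS-Chat-1M conversations into (source, target) pairs in chat_1turn.csv; pre_embed.py encodes both columns with a Snowflake Arctic embedder into chat_embeddings.pt; main.py trains a SemanticMapper from source embeddings to target embeddings; decoder.py and test_full.py train a GRU decoder that turns target embeddings back into text; test_chat.py chains mapper + decoder into an interactive chat loop; test_embed.py and testcuda.py are quick throughput and CUDA sanity checks.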
decoder.py
ADDED
@@ -0,0 +1,103 @@

import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from transformers import T5Tokenizer
from sentence_transformers import SentenceTransformer

# ===== CONFIG =====
INPUT_FILE = "chat_1turn.csv"
EMB_FILE = "chat_embeddings.pt"
MODEL_NAME = "Snowflake/snowflake-arctic-embed-l-v2.0"
EPOCHS = 80
BATCH_SIZE = 16
HIDDEN_DIM = 512
MAX_LEN = 64

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ===== Load CSV =====
df = pd.read_csv(INPUT_FILE)
sources = df["source"].fillna("").tolist()
targets = df["target"].fillna("").tolist()

# ===== Tokenizer =====
tokenizer = T5Tokenizer.from_pretrained("t5-small")
target_enc = tokenizer(targets, padding=True, truncation=True,
                       return_tensors="pt", max_length=MAX_LEN)
input_ids = target_enc["input_ids"].to(device)
attention_mask = target_enc["attention_mask"].to(device)

# ===== Load embeddings =====
emb_data = torch.load(EMB_FILE)
x_embeddings = emb_data["source"].to(device)  # not used directly in this training
y_embeddings = emb_data["target"].to(device)  # used to condition the decoder

# ===== Decoder =====
class EmbeddingDecoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, vocab_size):
        super().__init__()
        self.bridge = nn.Linear(input_dim, hidden_dim)
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, emb_vec, target_ids=None, teacher_forcing_ratio=0.5, max_len=MAX_LEN):
        hidden = self.bridge(emb_vec).unsqueeze(0)  # [1,B,H]
        B = emb_vec.size(0)
        outputs = []

        # start with pad token (T5 has pad=0, eos=1)
        inp = torch.full((B, 1), tokenizer.pad_token_id, device=emb_vec.device)

        for t in range(max_len):
            inp_emb = self.embed(inp)                # [B,1,H]
            out, hidden = self.gru(inp_emb, hidden)  # [B,1,H]
            logits = self.fc(out.squeeze(1))         # [B,V]
            outputs.append(logits.unsqueeze(1))

            if target_ids is not None and t < target_ids.size(1) and torch.rand(1).item() < teacher_forcing_ratio:
                inp = target_ids[:, t].unsqueeze(1)
            else:
                inp = torch.argmax(logits, dim=-1, keepdim=True)

        return torch.cat(outputs, dim=1)  # [B, max_len, V]

# ===== Train =====
decoder = EmbeddingDecoder(y_embeddings.shape[1], HIDDEN_DIM, tokenizer.vocab_size).to(device)
optimizer = optim.Adam(decoder.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

print("Training decoder...")
for epoch in range(EPOCHS):
    decoder.train()
    total_loss = 0.0
    for i in range(0, len(y_embeddings), BATCH_SIZE):
        xb = y_embeddings[i:i+BATCH_SIZE]
        yb = input_ids[i:i+BATCH_SIZE]

        optimizer.zero_grad()
        logits = decoder(xb, target_ids=yb, teacher_forcing_ratio=0.7, max_len=yb.size(1))
        loss = criterion(logits.reshape(-1, logits.size(-1)), yb.reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {total_loss:.4f}")

# ===== Inference =====
embedder = SentenceTransformer(MODEL_NAME, device=device)

def generate(text, max_len=30, use_mapper=False, mapper=None):
    with torch.no_grad():
        # embed new text
        emb = embedder.encode([text], convert_to_tensor=True, device=device)
        if use_mapper and mapper is not None:
            emb = mapper(emb)
        logits = decoder(emb, target_ids=None, teacher_forcing_ratio=0.0, max_len=max_len)
        ids = torch.argmax(logits, dim=-1).squeeze(0).tolist()
        return tokenizer.decode(ids, skip_special_tokens=True)

# ===== Test =====
print("Hi ->", generate("Hi"))
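
Note: decoder.py trains its EmbeddingDecoder but never persists it, while test_chat.py expects an embedding_decoder.pth with "state_dict", "dim", and "vocab_size" keys (the layout test_full.py writes). A minimal sketch of a compatible save at the end of this script, using the variables defined above; be aware that test_chat.py's EmbeddingDecoder matches the test_full.py architecture (concat conditioning, LayerNorm, tied weights), not this file's simpler GRU, so a checkpoint is only loadable by the matching class definition:

# Sketch: persist the decoder in the checkpoint layout test_chat.py loads.
torch.save(
    {"state_dict": decoder.state_dict(),
     "dim": y_embeddings.shape[1],
     "vocab_size": tokenizer.vocab_size},
    "embedding_decoder.pth",
)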
formatter.py
ADDED
@@ -0,0 +1,77 @@

from datasets import load_dataset
from transformers import T5Tokenizer
import pandas as pd, csv
from tqdm import tqdm

# ── Config ────────────────────────────────────────────────────────────────
jsonl_path = "lmsys_chat_1m_full.jsonl"  # local file
use_subset = False       # False ⇒ full 1 M rows
num_samples = 500        # if subset
max_turn_pairs = 1       # keep the last N user+assistant pairs (here 1 pair = 2 lines)
max_input_tokens = 512   # fits t5-small/base
# ──────────────────────────────────────────────────────────────────────────

tok = T5Tokenizer.from_pretrained("t5-small")
ds = load_dataset("json", data_files=jsonl_path, split="train")

if use_subset:
    ds = ds.select(range(min(num_samples, len(ds))))
    print(f"🔍 subset → {len(ds)} rows")

def mostly_ascii(s: str, threshold: float = .3) -> bool:
    try:
        return sum(ord(ch) > 127 for ch in s) / len(s) < threshold
    except ZeroDivisionError:
        return False

def format_turns(conv):
    return [f"{m['role'].capitalize()}: {m['content'].strip()}" for m in conv]

def build_pair(turns, max_tokens=max_input_tokens):
    if len(turns) < max_turn_pairs * 2:
        return None

    # last N pairs
    use_turns = turns[-(max_turn_pairs * 2):]

    prompt = "chat:\n\n" + "\n\n".join(use_turns[:-1])
    target = use_turns[-1].replace("Assistant: ", "", 1)

    # --- safe trimming loop --------------------------------------------
    for _ in range(max_turn_pairs):      # at most max_turn_pairs trims
        if len(tok.tokenize(prompt)) <= max_tokens:
            break                        # fits → good
        sep_pos = prompt.find("\n\n", len("chat:\n\n"))
        if sep_pos == -1:                # no more turns to drop
            return None
        prompt = "chat:\n\n" + prompt[sep_pos + 2:]
    else:
        # still too long after all trims
        return None
    # -------------------------------------------------------------------

    if len(prompt) < 30 or len(target) < 10:
        return None
    if not mostly_ascii(prompt + target):
        return None
    return prompt, target


rows, kept = [], 0
for ex in tqdm(ds, desc="formatting"):
    conv = ex.get("conversation")
    if not isinstance(conv, list):
        continue
    p = build_pair(format_turns(conv))
    if p:
        rows.append({"source": p[0], "target": p[1]})
        kept += 1

print(f"✅ kept {kept} examples")

pd.DataFrame(rows).to_csv(
    "chat_1turn.csv",
    index=False,
    quoting=csv.QUOTE_ALL,  # preserves embedded newlines
    encoding="utf-8"
)
print("💾 saved → chat_1turn.csv")
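
Since the CSV is written with QUOTE_ALL to preserve embedded newlines, a quick read-back check (a sketch, not part of the upload) confirms the pairs survive the round trip:

import pandas as pd
df = pd.read_csv("chat_1turn.csv")          # pandas re-parses quoted newlines fine
print(len(df), "rows")
print(df.iloc[0]["source"][:80])            # should start with "chat:"
assert df["source"].fillna("").str.startswith("chat:").all()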
main.py
ADDED
@@ -0,0 +1,79 @@

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# ===== Load precomputed embeddings =====
emb_data = torch.load("chat_embeddings.pt")

x_embeddings = emb_data["source"]  # [N, D]
y_embeddings = emb_data["target"]  # [N, D]

print("Source shape:", x_embeddings.shape)
print("Target shape:", y_embeddings.shape)

embedding_dim = x_embeddings.shape[1]
num_samples = x_embeddings.shape[0]

# ===== Device =====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

x_embeddings = x_embeddings.to(device)
y_embeddings = y_embeddings.to(device)

# ===== Define model =====
class SemanticMapper(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.ReLU(),
            nn.Linear(dim * 2, dim)
        )

    def forward(self, x):
        return self.net(x)

model = SemanticMapper(embedding_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CosineEmbeddingLoss()

# ===== Training config =====
epochs = 20
batch_size = 64
loss_history = []

# ===== Training loop =====
for epoch in range(epochs):
    perm = torch.randperm(num_samples, device=device)
    epoch_loss = 0.0
    for i in range(0, num_samples, batch_size):
        idx = perm[i:i + batch_size]
        x_batch = x_embeddings[idx]
        y_batch = y_embeddings[idx]
        target = torch.ones(x_batch.size(0), device=device)  # cosine target = +1

        y_pred = model(x_batch)
        loss = criterion(y_pred, y_batch, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / (num_samples / batch_size)
    loss_history.append(avg_loss)
    print(f"Epoch {epoch + 1}/{epochs} - Loss: {avg_loss:.6f}")

# ===== Plot loss curve =====
plt.plot(loss_history, marker="o")
plt.title("Training Loss (Cosine Embedding Loss)")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid(True)
plt.show()

# Save the trained model in the {"state_dict", "dim"} layout that
# test_chat.py loads (saving a bare state_dict would KeyError there)
torch.save({"state_dict": model.state_dict(), "dim": embedding_dim}, "semantic_mapper.pth")
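
A quick way to sanity-check the trained mapper, as a sketch using the tensors already in memory above: compare the cosine similarity of mapped source embeddings against the true target embeddings.

import torch.nn.functional as F
with torch.no_grad():
    cos = F.cosine_similarity(model(x_embeddings), y_embeddings, dim=-1)
print(f"mean cos(mapped, target) = {cos.mean().item():.4f}")  # higher is better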
pre_embed.py
ADDED
@@ -0,0 +1,55 @@

import torch
import pandas as pd
from sentence_transformers import SentenceTransformer
import time
import os

# CONFIGURATION
INPUT_FILE = "chat_1turn.csv"
OUTPUT_FILE = "chat_embeddings.pt"
MODEL_NAME = "Snowflake/snowflake-arctic-embed-l-v2.0"
BATCH_SIZE = 128  # Go big or go slow
USE_GPU = torch.cuda.is_available()
MAX_ROWS = 2000   # Set to e.g. 1000 for quick dev tests

# 🔧 Sanity checks
assert os.path.exists(INPUT_FILE), f"❌ File not found: {INPUT_FILE}"

# 🚀 Load model
print(f"🧠 Loading model: {MODEL_NAME} {'[GPU]' if USE_GPU else '[CPU]'}")
model = SentenceTransformer(MODEL_NAME, device="cuda" if USE_GPU else "cpu")

# 📂 Load data
print("📂 Reading CSV...")
df = pd.read_csv(INPUT_FILE)
assert 'source' in df.columns and 'target' in df.columns, "❌ Missing 'source' or 'target' column!"

if MAX_ROWS:
    df = df.head(MAX_ROWS)

sources = df['source'].fillna("").tolist()
targets = df['target'].fillna("").tolist()

# ⏱️ Embed all at once
def embed_all(texts, label):
    print(f"⚙️ Embedding {label} ({len(texts)} items)...")
    start = time.time()
    embeddings = model.encode(
        texts,
        batch_size=BATCH_SIZE,
        convert_to_tensor=True,
        normalize_embeddings=True,
        show_progress_bar=True,
        device="cuda" if USE_GPU else "cpu",
        # note: encode() has no torch_dtype parameter; the original
        # torch_dtype=torch.int8 kwarg is removed here (it would raise
        # a TypeError or be silently ignored, depending on version)
    )
    print(f"✅ {label} embedding done in {time.time() - start:.2f}s")
    return embeddings

source_tensor = embed_all(sources, "source")
target_tensor = embed_all(targets, "target")

# 💾 Save
print(f"💾 Saving to {OUTPUT_FILE}...")
torch.save({"source": source_tensor, "target": target_tensor}, OUTPUT_FILE)
print(f"✅ Saved {len(sources)} embeddings to {OUTPUT_FILE}")
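
If int8 output was the intent behind the removed torch_dtype flag, recent sentence-transformers releases expose scalar quantization explicitly. A sketch, assuming sentence-transformers >= 2.6 is installed (an assumption about the environment):

# Sketch: post-hoc int8 quantization (requires sentence-transformers >= 2.6).
from sentence_transformers.quantization import quantize_embeddings
emb_f32 = model.encode(sources, batch_size=BATCH_SIZE, normalize_embeddings=True)
emb_i8 = quantize_embeddings(emb_f32, precision="int8")  # numpy int8 array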
test_chat.py
ADDED
@@ -0,0 +1,91 @@

#!/usr/bin/env python3
# chat.py - use trained mapper + decoder interactively

import torch
from transformers import T5Tokenizer
from sentence_transformers import SentenceTransformer
import torch.nn as nn

# ===== CONFIG =====
MAPPER_PTH = "semantic_mapper.pth"
DECODER_PTH = "embedding_decoder.pth"
MODEL_NAME = "Snowflake/snowflake-arctic-embed-l-v2.0"
MAX_LEN = 4096
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ===== LOAD TOKENIZER =====
tokenizer = T5Tokenizer.from_pretrained("t5-small")
pad_id = tokenizer.pad_token_id
eos_id = tokenizer.eos_token_id

# ===== MODEL CLASSES (same defs as training) =====
class SemanticMapper(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.ReLU(),
            nn.Linear(dim * 2, dim)
        )
    def forward(self, x): return self.net(x)

class EmbeddingDecoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, vocab_size, p=0.2):
        super().__init__()
        self.bridge = nn.Linear(input_dim, hidden_dim)     # emb -> h0
        self.embed = nn.Embedding(vocab_size, hidden_dim)  # token -> hidden
        self.gru = nn.GRU(hidden_dim + input_dim, hidden_dim, batch_first=True)
        self.ln = nn.LayerNorm(hidden_dim)
        self.fc = nn.Linear(hidden_dim, vocab_size, bias=True)
        self.drop = nn.Dropout(p)
        # tie weights
        self.fc.weight = self.embed.weight

    @torch.no_grad()
    def greedy_decode(self, emb_vec, max_len, start_id, eos_id):
        B, _ = emb_vec.shape
        h = torch.tanh(self.bridge(emb_vec)).unsqueeze(0)
        inp = torch.full((B, 1), start_id, dtype=torch.long, device=emb_vec.device)
        out_ids = []
        for _ in range(max_len):
            token_h = self.drop(self.embed(inp))  # [B,1,H]
            step_in = torch.cat([token_h, emb_vec.unsqueeze(1)], dim=-1)
            out, h = self.gru(step_in, h)
            out = self.ln(out.squeeze(1))
            logits = self.fc(self.drop(out))
            logits[:, pad_id] = -1e9
            next_id = torch.argmax(logits, dim=-1)
            out_ids.append(next_id.unsqueeze(1))
            if (next_id == eos_id).all():
                break
            inp = next_id.unsqueeze(1)
        return torch.cat(out_ids, dim=1)


# ===== LOAD MODELS =====
mapper_ckpt = torch.load(MAPPER_PTH, map_location=DEVICE)
mapper = SemanticMapper(mapper_ckpt["dim"]).to(DEVICE)
mapper.load_state_dict(mapper_ckpt["state_dict"])
mapper.eval()

dec_ckpt = torch.load(DECODER_PTH, map_location=DEVICE)
decoder = EmbeddingDecoder(dec_ckpt["dim"], 512, dec_ckpt["vocab_size"]).to(DEVICE)
decoder.load_state_dict(dec_ckpt["state_dict"])
decoder.eval()

embedder = SentenceTransformer(MODEL_NAME, device=DEVICE)

# ===== CHAT LOOP =====
def chat():
    print("Chat ready. Type 'quit' to exit.")
    while True:
        user = input("User: ").strip()
        if not user or user.lower() in {"quit", "exit"}:
            break
        x = embedder.encode([user], convert_to_tensor=True, device=DEVICE).detach().clone()
        y_pred = mapper(x)
        ids = decoder.greedy_decode(y_pred, max_len=MAX_LEN,
                                    start_id=pad_id, eos_id=eos_id)[0].tolist()
        reply = tokenizer.decode(ids, skip_special_tokens=True)
        print("Bot:", reply)

if __name__ == "__main__":
    chat()
test_embed.py
ADDED
@@ -0,0 +1,18 @@

from sentence_transformers import SentenceTransformer
import time
import torch

model = SentenceTransformer("Snowflake/snowflake-arctic-embed-xs", device="cuda")

texts = ["The quick brown fox jumps over the lazy dog."] * 1000

start = time.time()
embeddings = model.encode(
    texts,
    batch_size=512,
    convert_to_tensor=True,
    normalize_embeddings=True,
    device="cuda",
    # encode() has no torch_dtype parameter, so the original
    # torch_dtype=torch.int8 kwarg is dropped here
)
print(f"⏱️ Embedded 1000 items in {time.time() - start:.2f} seconds")
test_full.py
ADDED
@@ -0,0 +1,203 @@

#!/usr/bin/env python3
import torch, torch.nn as nn, torch.optim as optim
import pandas as pd
import matplotlib.pyplot as plt
from transformers import T5Tokenizer
from sentence_transformers import SentenceTransformer

# ==== Config ====
EMB_FILE = "chat_embeddings.pt"  # {"source": [N,D], "target": [N,D]}
CSV_FILE = "chat_1turn.csv"      # columns: source, target
MODEL_NAME = "Snowflake/snowflake-arctic-embed-l-v2.0"
EPOCHS_MAPPER = 20
EPOCHS_DECODER = 160
BATCH_SIZE_MAP = 64
BATCH_SIZE_DEC = 64
LR_MAPPER = 1e-3
LR_DECODER = 1e-3
HIDDEN_DIM = 512
MAX_LEN = 64
PLOT_LOSS = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ==== Load embeddings & CSV ====
emb = torch.load(EMB_FILE, map_location=device)
x_embeddings = emb["source"].to(device)  # [N,D]
y_embeddings = emb["target"].to(device)  # [N,D]
N, D = x_embeddings.shape
print(f"Loaded embeddings: N={N}, D={D}")

df = pd.read_csv(CSV_FILE)
assert "target" in df.columns
targets = df["target"].fillna("").tolist()

# ==== Mapper: x_emb -> y_emb ====
class SemanticMapper(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim*2), nn.ReLU(),
            nn.Linear(dim*2, dim)
        )
    def forward(self, x): return self.net(x)

mapper = SemanticMapper(D).to(device)
opt_map = optim.Adam(mapper.parameters(), lr=LR_MAPPER)
crit_map = nn.CosineEmbeddingLoss()

print("\nTraining mapper...")
map_losses = []
for ep in range(EPOCHS_MAPPER):
    perm = torch.randperm(N, device=device)
    total = 0.0; steps = 0
    for i in range(0, N, BATCH_SIZE_MAP):
        idx = perm[i:i+BATCH_SIZE_MAP]
        xb, yb = x_embeddings[idx], y_embeddings[idx]
        tgt = torch.ones(xb.size(0), device=device)
        pred = mapper(xb)
        loss = crit_map(pred, yb, tgt)
        opt_map.zero_grad(); loss.backward()
        opt_map.step()
        total += loss.item(); steps += 1
    avg = total / max(1, steps)
    map_losses.append(avg)
    print(f"Mapper Epoch {ep+1}/{EPOCHS_MAPPER} - Loss: {avg:.6f}")

if PLOT_LOSS:
    plt.figure(); plt.plot(map_losses, marker="o"); plt.title("Mapper Loss"); plt.grid(True); plt.show()

torch.save({"state_dict": mapper.state_dict(), "dim": D}, "semantic_mapper.pth")
print("Saved mapper -> semantic_mapper.pth")

# ==== Decoder: y_emb -> target text ====
tokenizer = T5Tokenizer.from_pretrained("t5-small")
tok = tokenizer(targets, padding=True, truncation=True, max_length=MAX_LEN,
                return_tensors="pt", add_special_tokens=True)
labels = tok["input_ids"].to(device)  # [N,L]
pad_id = tokenizer.pad_token_id
eos_id = tokenizer.eos_token_id  # T5 uses </s> as EOS

# Build shifted inputs for strict teacher forcing:
# y_in[0] = BOS (use pad_id for T5), then y_in[t] = labels[t-1]
y_in = torch.full_like(labels, pad_id)
y_in[:, 1:] = labels[:, :-1]
y_out = labels  # predict labels[t] given y_in[t]

class EmbeddingDecoder(nn.Module):
    """
    Strong conditioning: concat emb each step.
    Weight tying: embed.weight = fc.weight.
    Deterministic teacher forcing via pre-built y_in (no ratios).
    """
    def __init__(self, input_dim, hidden_dim, vocab_size, p=0.2):
        super().__init__()
        self.bridge = nn.Linear(input_dim, hidden_dim)     # emb -> h0
        self.embed = nn.Embedding(vocab_size, hidden_dim)  # token -> hidden
        self.gru = nn.GRU(hidden_dim + input_dim, hidden_dim, batch_first=True)
        self.ln = nn.LayerNorm(hidden_dim)
        self.fc = nn.Linear(hidden_dim, vocab_size, bias=True)
        self.drop = nn.Dropout(p)
        # Tie weights
        self.fc.weight = self.embed.weight

    def forward_teacher_forced(self, emb_vec, in_ids, max_len):
        """
        emb_vec: [B,D], in_ids: [B,L] (strict teacher forcing inputs)
        Returns logits: [B,L,V]
        """
        B, D_in = emb_vec.shape
        h = torch.tanh(self.bridge(emb_vec)).unsqueeze(0)  # [1,B,H]
        logits_all = []
        for t in range(max_len):
            inp = in_ids[:, t].unsqueeze(1)                               # [B,1]
            token_h = self.drop(self.embed(inp))                          # [B,1,H]
            step_in = torch.cat([token_h, emb_vec.unsqueeze(1)], dim=-1)  # [B,1,H+D]
            out, h = self.gru(step_in, h)                                 # [B,1,H]
            out = self.ln(out.squeeze(1))                                 # [B,H]
            logits = self.fc(self.drop(out))                              # [B,V]
            logits_all.append(logits.unsqueeze(1))
        return torch.cat(logits_all, dim=1)  # [B,L,V]

    @torch.no_grad()
    def greedy_decode(self, emb_vec, max_len, start_id, eos_id):
        """
        Pure greedy with EOS stop; forbids PAD to reduce loops.
        """
        B, _ = emb_vec.shape
        h = torch.tanh(self.bridge(emb_vec)).unsqueeze(0)
        inp = torch.full((B, 1), start_id, dtype=torch.long, device=emb_vec.device)
        out_ids = []
        done = torch.zeros(B, dtype=torch.bool, device=emb_vec.device)

        for _ in range(max_len):
            token_h = self.embed(inp)  # [B,1,H]
            step_in = torch.cat([token_h, emb_vec.unsqueeze(1)], dim=-1)
            out, h = self.gru(step_in, h)
            logits = self.fc(out.squeeze(1))  # [B,V]
            logits[:, pad_id] = -1e9          # discourage PAD
            next_id = torch.argmax(logits, dim=-1)  # [B]
            out_ids.append(next_id.unsqueeze(1))
            done |= (next_id == eos_id)
            if done.all(): break
            inp = next_id.unsqueeze(1)
        return torch.cat(out_ids, dim=1)  # [B,T]

decoder = EmbeddingDecoder(D, HIDDEN_DIM, tokenizer.vocab_size).to(device)
opt_dec = optim.Adam(decoder.parameters(), lr=LR_DECODER)
crit_dec = nn.CrossEntropyLoss(ignore_index=pad_id)  # no smoothing (small N)

print("\nTraining decoder...")
dec_losses = []
steps = (N + BATCH_SIZE_DEC - 1) // BATCH_SIZE_DEC
for ep in range(EPOCHS_DECODER):
    perm = torch.randperm(N, device=device)
    total = 0.0
    for i in range(0, N, BATCH_SIZE_DEC):
        idx = perm[i:i+BATCH_SIZE_DEC]
        eb = y_embeddings[idx]  # condition on TRUE target-space embeddings
        yin = y_in[idx]         # shifted inputs
        yout = y_out[idx]       # labels

        opt_dec.zero_grad()
        logits = decoder.forward_teacher_forced(eb, yin, max_len=yout.size(1))  # [B,L,V]
        loss = crit_dec(logits.reshape(-1, logits.size(-1)), yout.reshape(-1))
        loss.backward()
        nn.utils.clip_grad_norm_(decoder.parameters(), 1.0)
        opt_dec.step()
        total += loss.item()
    avg = total / max(1, steps)
    dec_losses.append(avg)
    print(f"Decoder Epoch {ep+1}/{EPOCHS_DECODER} - Loss: {avg:.4f}")

if PLOT_LOSS:
    plt.figure(); plt.plot(dec_losses, marker="o"); plt.title("Decoder Loss"); plt.grid(True); plt.show()

torch.save({"state_dict": decoder.state_dict(), "dim": D, "vocab_size": tokenizer.vocab_size},
           "embedding_decoder.pth")
print("Saved decoder -> embedding_decoder.pth")

# ==== E2E inference ====
embedder = SentenceTransformer(MODEL_NAME, device=device)
dim = embedder.get_sentence_embedding_dimension()
if dim is not None and dim != D:
    # fail loudly; the original wrapped this in try/except Exception,
    # which swallowed the very RuntimeError it raised
    raise RuntimeError(f"Embedder dim {dim} != training dim {D}. Regenerate embeddings with same MODEL_NAME.")

@torch.no_grad()
def generate(text: str, max_len: int = 24) -> str:
    # source -> x_emb
    x = embedder.encode([text], convert_to_tensor=True, device=device)  # [1,D]
    # map -> y_emb
    y_pred = mapper(x)  # [1,D]
    # decode y_emb -> text
    ids = decoder.greedy_decode(y_pred, max_len=max_len, start_id=pad_id, eos_id=eos_id)[0].tolist()
    return tokenizer.decode(ids, skip_special_tokens=True)

print("\nE2E test:")
inp = "User: Hi"
print(f"{inp} ->", generate(inp))
testcuda.py
ADDED
@@ -0,0 +1,6 @@

import torch

print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(torch.cuda.current_device())
print(torch.cuda.get_device_name(0))