|
|
import torch
|
|
|
import pandas as pd
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
import time
|
|
|
import os
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
USE_GPU: bool = torch.cuda.is_available()  # True when torch reports a usable CUDA device
MODEL_NAME: str = "Snowflake/snowflake-arctic-embed-l-v2.0"  # identifier handed to SentenceTransformer
BATCH_SIZE: int = 128  # batch_size forwarded to model.encode
MAX_ROWS: int = 2000  # keep only the first MAX_ROWS rows; a falsy value keeps every row
INPUT_FILE: str = "chat_1turn.csv"  # CSV read as input
OUTPUT_FILE: str = "chat_embeddings.pt"  # destination passed to torch.save
|
|
|
|
|
|
|
|
|
# Fail fast when the input CSV is absent. An explicit raise is used instead of
# `assert` because assert statements are removed when Python runs with -O.
if not os.path.exists(INPUT_FILE):
    raise FileNotFoundError(f"❌ File not found: {INPUT_FILE}")
|
|
|
|
|
|
|
|
|
# Load the embedding model on the GPU when CUDA is available, otherwise on CPU.
# The status prefix was mis-encoded in the file; it is restored to a real emoji.
print(f"🔧 Loading model: {MODEL_NAME} {'[GPU]' if USE_GPU else '[CPU]'}")
model = SentenceTransformer(MODEL_NAME, device="cuda" if USE_GPU else "cpu")
|
|
|
|
|
|
|
|
|
# Read the input CSV and pull out the two text columns that get embedded.
print("📄 Reading CSV...")
df = pd.read_csv(INPUT_FILE)

# Validate with an explicit raise: `assert` is removed under `python -O`,
# which would let a malformed file through to a less helpful KeyError below.
if "source" not in df.columns or "target" not in df.columns:
    raise ValueError("❌ Missing 'source' or 'target' column!")

# A falsy MAX_ROWS (0 or None) disables the row cap.
if MAX_ROWS:
    df = df.head(MAX_ROWS)

# Missing cells become empty strings; astype(str) additionally guards against
# columns that pandas parsed as numbers, since the encoder is fed text.
sources = df["source"].fillna("").astype(str).tolist()
targets = df["target"].fillna("").astype(str).tolist()
|
|
|
|
|
|
|
|
|
def embed_all(texts, label):
    """Embed ``texts`` with the module-level ``model`` and report the elapsed time.

    Args:
        texts: List of strings to encode.
        label: Short name for this batch of texts; used only in the progress
            messages printed before and after encoding.

    Returns:
        The result of ``model.encode`` called with ``convert_to_tensor=True``
        and ``normalize_embeddings=True``.
    """
    print(f"⚙️ Embedding {label} ({len(texts)} items)...")
    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # wall-clock adjustments.
    start = time.perf_counter()
    # NOTE: the previous `torch_dtype=torch.int8` argument was dropped because
    # it is not a documented parameter of SentenceTransformer.encode. If
    # quantized output is wanted, the documented option is `precision=...`
    # (confirm against the installed sentence-transformers version).
    embeddings = model.encode(
        texts,
        batch_size=BATCH_SIZE,
        convert_to_tensor=True,
        normalize_embeddings=True,
        show_progress_bar=True,
        device="cuda" if USE_GPU else "cpu",
    )
    elapsed = time.perf_counter() - start
    print(f"✅ {label} embedding done in {elapsed:.2f}s")
    return embeddings
|
|
|
|
|
|
# Embed both columns, then persist them together in a single file.
source_tensor = embed_all(sources, "source")
target_tensor = embed_all(targets, "target")

print(f"💾 Saving to {OUTPUT_FILE}...")
# Stored as a dict so consumers can look the tensors up by the keys
# "source" and "target".
torch.save({"source": source_tensor, "target": target_tensor}, OUTPUT_FILE)
# The success message was split across two lines by a mis-encoded character;
# it is rejoined here into one valid f-string.
print(f"✅ Saved {len(sources)} embeddings to {OUTPUT_FILE}")
|
|
|
|