import os
import random

import numpy as np
import requests
import sentencepiece as spm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader
from tqdm import tqdm

# Hyperparameters
TOKENIZER_PATH = "ko_unigram.model"
DATA_PATH = "corpus.txt"
MAX_LEN = 128        # maximum tokens per sentence
EMBED_DIM = 384      # token embedding width
LATENT_DIM = 384     # size of the final sentence vector
BATCH_SIZE = 384
NEGATIVE_RATIO = 1   # negative pairs sampled per positive pair

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def download_file(url, save_path):
    """Stream a remote file to disk in 16 KiB chunks."""
    r = requests.get(url, stream=True)
    r.raise_for_status()
    with open(save_path, "wb") as f:
        for chunk in r.iter_content(chunk_size=8192 * 2):
            f.write(chunk)
    print(f"Saved {save_path}")


# Fetch the tokenizer model and the training corpus on first run.
if not os.path.exists(TOKENIZER_PATH):
    download_file(
        "https://huggingface.co/Yuchan5386/inlam-100m/resolve/main/ko_unigram.model?download=true",
        TOKENIZER_PATH,
    )
if not os.path.exists(DATA_PATH):
    download_file(
        "https://huggingface.co/datasets/Yuchan5386/1/resolve/main/shuffled_corpus.txt?download=true",
        DATA_PATH,
    )

# Tokenizer and encoding helpers
sp = spm.SentencePieceProcessor(model_file=TOKENIZER_PATH)
_pad = sp.piece_to_id("<pad>")
pad_id = _pad if _pad != -1 else 0
vocab_size = sp.get_piece_size()


def encode_sentence(sentence, max_len=MAX_LEN):
    """Tokenize a sentence and truncate it to at most max_len ids."""
    return sp.encode(sentence, out_type=int)[:max_len]


def pad_sentence(tokens):
    """Right-pad a token id list with pad_id up to MAX_LEN."""
    return tokens + [pad_id] * (MAX_LEN - len(tokens))
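

# Optional sanity check, a minimal sketch (the sample string is arbitrary and
# not part of the corpus): encoding then padding should always produce exactly
# MAX_LEN ids.
_ids = pad_sentence(encode_sentence("안녕하세요"))
assert len(_ids) == MAX_LEN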


# Streaming pair dataset: every sentence yields one positive pair (the
# sentence paired with itself, label 1.0) followed by NEGATIVE_RATIO negative
# pairs (the sentence paired with a random other sentence, label 0.0).
class PairStream(IterableDataset):
    def __init__(self, txt_path, negative_ratio):
        self.sentences = [line.strip() for line in open(txt_path, encoding="utf-8") if line.strip()]
        self.neg_ratio = negative_ratio

    def __iter__(self):
        while True:
            for s1 in self.sentences:
                x1 = pad_sentence(encode_sentence(s1))
                yield (torch.tensor(x1), torch.tensor(x1), torch.tensor(1.0))
                for _ in range(self.neg_ratio):
                    s2 = random.choice(self.sentences)
                    # Re-draw if we sampled s1 itself; an identical pair must
                    # not be labeled as a negative.
                    while s2 == s1:
                        s2 = random.choice(self.sentences)
                    x2 = pad_sentence(encode_sentence(s2))
                    yield (torch.tensor(x1), torch.tensor(x2), torch.tensor(0.0))


stream_ds = PairStream(DATA_PATH, NEGATIVE_RATIO)
loader = DataLoader(stream_ds, batch_size=BATCH_SIZE)
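
# Quick smoke test, a sketch you can delete: pull one small batch and check
# the tensor shapes before committing to a long training run.
_x1, _x2, _y = next(iter(DataLoader(stream_ds, batch_size=4)))
print(_x1.shape, _x2.shape, _y.shape)  # torch.Size([4, 128]) twice, then torch.Size([4])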


# Pre-norm transformer block with a SwiGLU-style feed-forward: WB projects to
# 3*embed_dim, the result is split into two halves of size 3*embed_dim/2, and
# the SiLU of one half gates the other before W projects back to embed_dim.
# (latent_dim is accepted but unused here; it mirrors SentenceEncoder's
# signature.)
class EncoderBlock(nn.Module):
    def __init__(self, embed_dim, latent_dim):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim, num_heads=8, batch_first=True)
        self.WB = nn.Linear(embed_dim, embed_dim * 3)
        self.W = nn.Linear(embed_dim * 3 // 2, embed_dim)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.ln3 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Self-attention sub-layer with residual connection.
        x1 = self.ln1(x)
        attn, _ = self.mha(x1, x1, x1)
        x = attn + x
        # Gated feed-forward sub-layer with residual connection.
        x2 = self.ln2(x)
        a, b = torch.chunk(self.WB(x2), 2, dim=-1)
        g = F.silu(a) * b
        return self.ln3(self.W(g)) + x


class SentenceEncoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, latent_dim, max_len):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_id)
        self.pos = nn.Embedding(max_len, embed_dim)
        self.blocks = nn.ModuleList([EncoderBlock(embed_dim, latent_dim) for _ in range(2)])
        self.ln_f = nn.LayerNorm(embed_dim)
        self.latent = nn.Linear(embed_dim, latent_dim)

    def forward(self, x):
        b, l = x.shape
        pos_ids = torch.arange(l, device=x.device).unsqueeze(0).expand(b, l)
        x = self.embed(x) + self.pos(pos_ids)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        # Mean-pool over all positions into a single sentence vector. Note
        # that padding positions are included in both attention and this mean;
        # a key_padding_mask and a masked mean would exclude them.
        x = x.mean(dim=1)
        return torch.tanh(self.latent(x))


encoder = SentenceEncoder(vocab_size, EMBED_DIM, LATENT_DIM, MAX_LEN).to(device)
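
# Shape check, a sketch under the config above: a dummy batch of token ids
# should map to one LATENT_DIM vector per sentence, bounded to (-1, 1) by the
# final tanh.
with torch.no_grad():
    _vec = encoder(torch.full((2, MAX_LEN), pad_id, dtype=torch.long, device=device))
print(_vec.shape)  # torch.Size([2, 384])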


def cosine_sim(v1, v2, eps=1e-8):
    """Row-wise cosine similarity between two batches of vectors."""
    dot = (v1 * v2).sum(dim=-1)
    norm = v1.norm(dim=-1) * v2.norm(dim=-1) + eps
    return dot / norm


def contrastive_loss(pred, label, margin=0.7):
    """Contrastive loss on the cosine distance dist = 1 - cos_sim: positive
    pairs (label 1) are pulled toward dist = 0, while negative pairs (label 0)
    are pushed out until dist >= margin, beyond which they incur no loss."""
    dist = 1 - pred
    pos_loss = label * dist.pow(2)
    neg_loss = (1 - label) * torch.clamp(margin - dist, min=0).pow(2)
    return (pos_loss + neg_loss).mean()
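

# Worked example, illustrative numbers only: a positive pair at cos_sim 0.9
# costs (1 - 0.9)^2 = 0.01; a negative at 0.5 has dist 0.5 < margin, costing
# (0.7 - 0.5)^2 = 0.04; a negative at 0.2 has dist 0.8 >= margin, costing 0.
_demo = contrastive_loss(torch.tensor([0.9, 0.5, 0.2]), torch.tensor([1.0, 0.0, 0.0]))
print(_demo)  # mean of [0.01, 0.04, 0.0] ~= 0.0167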


optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-5)

# Compile the model and the similarity/loss functions (requires PyTorch 2.x).
encoder = torch.compile(encoder)
cosine_sim = torch.compile(cosine_sim)
contrastive_loss = torch.compile(contrastive_loss)

# 23,119,910 is presumably the hard-coded line count of the corpus; a full
# pass over all generated pairs would take
# len(stream_ds.sentences) * (1 + NEGATIVE_RATIO) // BATCH_SIZE steps.
steps_per_epoch = 23119910 // BATCH_SIZE

encoder.train()

progress = tqdm(range(steps_per_epoch), desc="Training", ncols=120)

for step, batch in zip(progress, loader):
    x1, x2, y = [b.to(device) for b in batch]

    # Encode both sentences and score their similarity.
    v1 = encoder(x1)
    v2 = encoder(x2)
    pred = cosine_sim(v1, v2)

    loss = contrastive_loss(pred, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    progress.set_postfix({"loss": f"{loss.item():.4f}"})


# Build a small retrieval index over the first LIMIT corpus sentences.
LIMIT = 4000
prompts = []
for i, line in enumerate(open(DATA_PATH, "r", encoding="utf-8")):
    if i >= LIMIT:
        break
    line = line.strip()
    if line:
        prompts.append(line)


@torch.no_grad()
def get_sentence_vector(sentence):
    tokens = pad_sentence(encode_sentence(sentence))
    x = torch.tensor([tokens]).to(device)
    return encoder(x).cpu().numpy()[0]


encoder.eval()  # leave training mode before encoding the corpus

# Vectors are cached on disk in float16 to keep the index small, at some cost
# in similarity precision. Delete corpus_vectors.npy after retraining; a
# stale cache will not match the current encoder.
if os.path.exists("corpus_vectors.npy"):
    corpus_vectors = np.load("corpus_vectors.npy")
else:
    corpus_vectors = np.stack([get_sentence_vector(p) for p in prompts]).astype(np.float16)
    np.save("corpus_vectors.npy", corpus_vectors)

corpus_norms = np.linalg.norm(corpus_vectors, axis=1)


def search(query, top_k=3):
    """Return the top_k corpus sentences most cosine-similar to the query."""
    q_vec = get_sentence_vector(query).astype(np.float16)
    sims = corpus_vectors @ q_vec
    sims /= corpus_norms * np.linalg.norm(q_vec) + 1e-8
    top_idx = np.argsort(sims)[::-1][:top_k]
    return [(prompts[i], float(sims[i])) for i in top_idx]
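

# Self-retrieval check, a sketch: a sentence that is in the index should come
# back as its own top hit with similarity close to 1.0.
print(search(prompts[0], top_k=1))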


query = "점심이나 저녁을 우리와 함께 먹을 건가요?"  # "Will you have lunch or dinner with us?"
results = search(query)
for p, s in results:
    print(f"Prompt: {p}\nSimilarity: {s:.3f}\n---")