# model-prototype / Model_torch.py
# Yuchan
# Update Model_torch.py
# cc78280 verified
import os, json, random, numpy as np, torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader
import sentencepiece as spm
import requests
# ===============================
# 0️⃣ Environment setup
# ===============================
# Local paths of the SentencePiece model and the training corpus; both files
# are downloaded further down in this script when they do not exist yet.
TOKENIZER_PATH = "ko_unigram.model"
DATA_PATH = "corpus.txt"
# Every sentence is truncated/padded to this many token ids; it is also the
# size of the positional-embedding table handed to SentenceEncoder.
MAX_LEN = 128
# Width of the token/position embeddings inside the encoder.
EMBED_DIM = 384
# Size of the final sentence vector produced by the encoder.
LATENT_DIM = 384
# Number of (anchor, other, label) pairs per optimisation step.
BATCH_SIZE = 384
# Number of randomly drawn negative pairs emitted after each positive pair.
NEGATIVE_RATIO = 1
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ===============================
# 1️⃣ File download
# ===============================
def download_file(url, save_path, timeout=60):
    """Stream ``url`` to ``save_path`` on disk.

    The payload is first written to ``<save_path>.part`` and only moved to
    ``save_path`` once the transfer finished, so an interrupted download can
    never be mistaken for a complete file by a later ``os.path.exists`` check.

    Args:
        url: HTTP(S) address to fetch.
        save_path: Destination file path.
        timeout: Seconds to wait for the connection / for each read before
            ``requests`` gives up (passed straight to ``requests.get``).

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
        OSError: If the file cannot be written or moved.
    """
    tmp_path = f"{save_path}.part"
    # The context manager releases the pooled connection even when the
    # status check or the write loop raises.
    with requests.get(url, stream=True, timeout=timeout) as r:
        r.raise_for_status()
        with open(tmp_path, "wb") as f:
            for chunk in r.iter_content(8192 * 2):
                if chunk:  # skip keep-alive chunks
                    f.write(chunk)
    # Atomic on the same filesystem: readers see either nothing or the full file.
    os.replace(tmp_path, save_path)
    print(f"Saved {save_path}")
# Fetch the tokenizer model first, then the corpus, but only when the local
# copy is missing.
for _local_path, _remote_url in (
    (
        TOKENIZER_PATH,
        "https://huggingface.co/Yuchan5386/inlam-100m/resolve/main/ko_unigram.model?download=true",
    ),
    (
        DATA_PATH,
        "https://huggingface.co/datasets/Yuchan5386/1/resolve/main/shuffled_corpus.txt?download=true",
    ),
):
    if not os.path.exists(_local_path):
        download_file(_remote_url, _local_path)
# ===============================
# 2️⃣ Tokenizer setup
# ===============================
# Load the SentencePiece model from disk.
sp = spm.SentencePieceProcessor(TOKENIZER_PATH)
# Resolve the padding id. SentencePiece maps an unknown piece to the <unk> id
# rather than to -1, so comparing against -1 alone cannot detect a missing
# "<pad>" piece; instead verify that the returned id really round-trips to
# "<pad>" and fall back to id 0 otherwise.
_pad_piece_id = sp.piece_to_id("<pad>")
pad_id = (
    _pad_piece_id
    if 0 <= _pad_piece_id < sp.get_piece_size()
    and sp.id_to_piece(_pad_piece_id) == "<pad>"
    else 0
)
# Vocabulary size; used to size the token-embedding table.
vocab_size = sp.get_piece_size()
def encode_sentence(sentence, max_len=MAX_LEN):
    """Tokenize ``sentence`` with the module-level SentencePiece processor.

    Returns the list of integer token ids, cut off after ``max_len`` entries.
    """
    token_ids = sp.encode(sentence, out_type=int)
    return token_ids[:max_len]
def pad_sentence(tokens, max_len=None, pad_value=None):
    """Return a copy of ``tokens`` that is exactly ``max_len`` items long.

    Shorter inputs are right-padded with ``pad_value``; longer inputs are
    truncated, so the result always has a fixed length and can be stacked
    into a batch. The input list is never modified.

    Args:
        tokens: List of token ids.
        max_len: Target length. Defaults to the module-level ``MAX_LEN``.
        pad_value: Filler id. Defaults to the module-level ``pad_id``.

    Returns:
        A new list of length ``max_len``.
    """
    # Defaults are resolved at call time so the function follows the current
    # module configuration.
    if max_len is None:
        max_len = MAX_LEN
    if pad_value is None:
        pad_value = pad_id
    clipped = tokens[:max_len]
    return clipped + [pad_value] * (max_len - len(clipped))
# ===============================
# 3️⃣ Streaming Dataset
# ===============================
class PairStream(IterableDataset):
    """Endless stream of ``(anchor, other, label)`` training pairs.

    For every sentence of the corpus the stream yields one positive pair
    (the sentence paired with itself, label ``1.0``) followed by
    ``negative_ratio`` negative pairs (the sentence paired with a randomly
    drawn sentence, label ``0.0``). After the last sentence it starts over,
    so iteration only ends when the consumer stops pulling.

    All non-empty, stripped lines of the corpus are held in memory.

    Args:
        txt_path: UTF-8 text file with one sentence per line.
        negative_ratio: Number of negative pairs emitted per positive pair.
    """

    # Upper bound on re-draws when a random negative has the same text as the
    # anchor; bounded so a corpus of identical lines cannot loop forever.
    _MAX_NEGATIVE_RETRIES = 8

    def __init__(self, txt_path, negative_ratio):
        # Close the file deterministically instead of leaking the handle.
        with open(txt_path, encoding="utf-8") as corpus_file:
            stripped = (line.strip() for line in corpus_file)
            self.sentences = [s for s in stripped if s]
        self.neg_ratio = negative_ratio

    def _sample_negative(self, anchor_text):
        """Draw a random sentence, avoiding text identical to the anchor when possible."""
        candidate = random.choice(self.sentences)
        for _ in range(self._MAX_NEGATIVE_RETRIES):
            if candidate != anchor_text:
                break
            candidate = random.choice(self.sentences)
        return candidate

    def __iter__(self):
        # An empty corpus would otherwise make the ``while True`` loop below
        # spin forever without yielding anything.
        if not self.sentences:
            return
        while True:
            for s1 in self.sentences:
                # Encode the anchor once and reuse it for all pairs of this sentence.
                anchor = torch.tensor(pad_sentence(encode_sentence(s1)))
                yield (anchor, anchor.clone(), torch.tensor(1.0))
                for _ in range(self.neg_ratio):
                    s2 = self._sample_negative(s1)
                    other = torch.tensor(pad_sentence(encode_sentence(s2)))
                    yield (anchor, other, torch.tensor(0.0))
# Build the pair stream over the corpus; the constructor reads the whole file.
stream_ds = PairStream(DATA_PATH, NEGATIVE_RATIO)
# Group streamed pairs into batches of BATCH_SIZE. No shuffle or worker
# arguments are passed, so the DataLoader defaults apply.
loader = DataLoader(stream_ds, batch_size=BATCH_SIZE)
# ===============================
# 4️⃣ Sentence Encoder definition
# ===============================
class EncoderBlock(nn.Module):
    """Pre-norm self-attention block followed by a gated (SiLU) feed-forward layer.

    Data flow for an input ``x`` of shape ``(batch, seq, embed_dim)``:

    1. ``x = x + MHA(LN1(x))``
    2. ``a, b = split(WB(LN2(x)))``, each ``1.5 * embed_dim`` wide
    3. ``out = LN3(W(silu(a) * b)) + x``

    Args:
        embed_dim: Feature width of the input and the output. Must be even
            (the ``3 * embed_dim`` projection is split into two equal halves)
            and divisible by ``num_heads``.
        latent_dim: Accepted for signature compatibility; not used by the block.
        num_heads: Number of attention heads (default ``8``).

    Raises:
        ValueError: If ``embed_dim`` is odd.
    """

    def __init__(self, embed_dim, latent_dim, num_heads=8):
        super().__init__()
        if embed_dim % 2:
            raise ValueError(
                f"embed_dim must be even to split the gated projection, got {embed_dim}"
            )
        self.mha = nn.MultiheadAttention(embed_dim, num_heads=num_heads, batch_first=True)
        self.WB = nn.Linear(embed_dim, embed_dim * 3)
        self.W = nn.Linear(embed_dim * 3 // 2, embed_dim)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.ln3 = nn.LayerNorm(embed_dim)

    def forward(self, x, key_padding_mask=None):
        """Apply the block.

        Args:
            x: Tensor of shape ``(batch, seq, embed_dim)``.
            key_padding_mask: Optional ``(batch, seq)`` mask forwarded to
                ``nn.MultiheadAttention``; for a boolean mask, ``True`` marks
                key positions that attention must ignore. Rows in which every
                position is masked are the caller's responsibility.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x1 = self.ln1(x)
        # The attention weights are discarded, so do not ask for them.
        attn, _ = self.mha(
            x1, x1, x1, key_padding_mask=key_padding_mask, need_weights=False
        )
        x = attn + x
        x2 = self.ln2(x)
        w = self.WB(x2)
        # Split the 3*embed_dim projection into a gate half and a value half.
        a, b = torch.chunk(w, 2, dim=-1)
        g = F.silu(a) * b
        out = self.W(g)
        return self.ln3(out) + x
class SentenceEncoder(nn.Module):
    """Encode a batch of padded token-id sequences into fixed-size sentence vectors.

    Token and learned position embeddings are summed, passed through a stack
    of ``EncoderBlock`` layers and a final LayerNorm, mean-pooled over the
    non-padding positions, projected to ``latent_dim`` and squashed with tanh.

    Args:
        vocab_size: Number of rows in the token-embedding table.
        embed_dim: Width of token/position embeddings and of every block.
        latent_dim: Size of the returned sentence vector.
        max_len: Number of learned positions; inputs must not be longer.
        num_layers: Number of stacked ``EncoderBlock`` layers (default ``2``).
    """

    def __init__(self, vocab_size, embed_dim, latent_dim, max_len, num_layers=2):
        super().__init__()
        # The module-level ``pad_id`` marks the padding row of the embedding table.
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_id)
        self.pos = nn.Embedding(max_len, embed_dim)
        self.blocks = nn.ModuleList(
            [EncoderBlock(embed_dim, latent_dim) for _ in range(num_layers)]
        )
        self.ln_f = nn.LayerNorm(embed_dim)
        self.latent = nn.Linear(embed_dim, latent_dim)

    def forward(self, x):
        """Return sentence vectors of shape ``(batch, latent_dim)`` with values in ``(-1, 1)``.

        Args:
            x: Integer tensor of token ids with shape ``(batch, seq)``.
        """
        b, l = x.shape
        pos_ids = torch.arange(l, device=x.device).unsqueeze(0).expand(b, l)
        # 1.0 at real tokens, 0.0 at padding; trailing axis broadcasts over features.
        keep = (x != self.embed.padding_idx).unsqueeze(-1)
        h = self.embed(x) + self.pos(pos_ids)
        for block in self.blocks:
            h = block(h)
        h = self.ln_f(h)
        # Average only over real tokens: padding positions carry position
        # embeddings and attention output, so a plain mean would make the
        # sentence vector depend on how much padding was appended.
        keep = keep.to(h.dtype)
        token_count = keep.sum(dim=1).clamp(min=1.0)  # guard all-padding rows
        pooled = (h * keep).sum(dim=1) / token_count
        return torch.tanh(self.latent(pooled))
# Instantiate the sentence encoder with the configured sizes and move its
# parameters to the selected device.
encoder = SentenceEncoder(vocab_size, EMBED_DIM, LATENT_DIM, MAX_LEN).to(device)
# ===============================
# 5️⃣ Cosine + Contrastive Loss
# ===============================
def cosine_sim(v1, v2, eps=1e-8):
    """Cosine similarity of ``v1`` and ``v2`` along their last dimension.

    ``eps`` is added to the product of the two norms so that zero vectors
    yield ``0`` instead of a division by zero.
    """
    numerator = torch.sum(v1 * v2, dim=-1)
    denominator = v1.norm(dim=-1).mul(v2.norm(dim=-1)).add(eps)
    return numerator / denominator
def contrastive_loss(pred, label, margin=0.7):
    """Mean contrastive loss over a batch of similarity scores.

    ``pred`` is turned into a distance ``1 - pred``. Pairs with ``label`` 1 are
    penalised by the squared distance; pairs with ``label`` 0 are penalised by
    the squared amount by which their distance falls short of ``margin``.
    """
    distance = 1 - pred
    margin_gap = torch.clamp(margin - distance, min=0)
    per_pair = label * distance.pow(2) + (1 - label) * margin_gap.pow(2)
    return per_pair.mean()
# Adam over all encoder parameters. The optimizer is created before the name
# ``encoder`` is rebound to the compiled wrapper below.
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-5)
# Wrap the model and both loss helpers with torch.compile; from here on the
# module-level names refer to the compiled callables.
encoder = torch.compile(encoder)
cosine_sim = torch.compile(cosine_sim)
contrastive_loss = torch.compile(contrastive_loss)
# ===============================
# 6️⃣ Training loop
# ===============================
# Number of optimisation steps to run. 23119910 is hard-coded here
# (presumably the number of corpus lines -- TODO confirm).
steps_per_epoch = 23119910 // BATCH_SIZE
from tqdm import tqdm

encoder.train()
progress = tqdm(range(steps_per_epoch), desc="Training", ncols=120)
# ``loader`` never runs dry, so the progress range decides when the loop ends.
for step, batch in zip(progress, loader):
    x1, x2, y = (tensor.to(device) for tensor in batch)

    # Clear gradients left over from the previous step.
    optimizer.zero_grad()

    # Forward: embed both sides of each pair and score their similarity.
    v1 = encoder(x1)
    v2 = encoder(x2)
    pred = cosine_sim(v1, v2)
    loss = contrastive_loss(pred, y)

    # Backward and parameter update.
    loss.backward()
    optimizer.step()

    # Show the current loss in the tqdm progress bar.
    progress.set_postfix({"loss": f"{loss.item():.4f}"})
# ===============================
# 7️⃣ Build vectors for retrieval
# ===============================
# Only the first LIMIT lines of the corpus are considered for retrieval;
# blank lines inside that window are skipped, so ``prompts`` may end up
# shorter than LIMIT.
LIMIT = 4000
prompts = []
# ``with`` closes the corpus file as soon as the window has been read.
with open(DATA_PATH, "r", encoding="utf-8") as corpus_file:
    for i, line in enumerate(corpus_file):
        if i >= LIMIT:
            break
        line = line.strip()
        if line:
            prompts.append(line)
def get_sentence_vector(sentence):
    """Encode one sentence with the module-level encoder and return it as a NumPy vector.

    The sentence is tokenized, padded, run through ``encoder`` as a batch of
    one with gradient tracking disabled, and the single resulting row is
    returned.
    """
    with torch.no_grad():
        token_ids = pad_sentence(encode_sentence(sentence))
        batch = torch.tensor([token_ids]).to(device)
        vectors = encoder(batch).cpu().numpy()
    return vectors[0]
# Load cached prompt vectors when a cache file exists and still matches the
# current prompt list; otherwise encode every prompt and write the cache.
# NOTE(review): the cache is not tied to the model weights. Because this
# script retrains the encoder on every run, a cache written by an earlier run
# holds vectors of a different model -- delete "corpus_vectors.npy" to force a
# rebuild.
corpus_vectors = None
if os.path.exists("corpus_vectors.npy"):
    corpus_vectors = np.load("corpus_vectors.npy")
    # A cache with the wrong number of rows would pair similarities with the
    # wrong prompts (or index out of range), so discard it.
    if corpus_vectors.ndim != 2 or corpus_vectors.shape[0] != len(prompts):
        corpus_vectors = None
if corpus_vectors is None:
    corpus_vectors = np.stack([get_sentence_vector(p) for p in prompts]).astype(np.float16)
    np.save("corpus_vectors.npy", corpus_vectors)
# Vectors are stored as float16 to save space, but the norms are accumulated
# in float32 to avoid float16 rounding in the sum of squares.
corpus_norms = np.linalg.norm(corpus_vectors.astype(np.float32), axis=1)
# ===============================
# 8️⃣ Search function
# ===============================
def search(query, top_k=3):
    """Return the ``top_k`` prompts most similar to ``query``.

    The query is encoded with ``get_sentence_vector`` and compared against
    the module-level ``corpus_vectors`` by cosine similarity.

    Args:
        query: Text to search for.
        top_k: Maximum number of results; values ``<= 0`` yield an empty list.

    Returns:
        List of ``(prompt, similarity)`` tuples, best match first.
    """
    if top_k <= 0:
        return []
    # Do the arithmetic in float32: in float16 the 1e-8 epsilon underflows to
    # zero, so a zero-norm vector would produce NaN/inf scores.
    q_vec = get_sentence_vector(query).astype(np.float32)
    sims = corpus_vectors.astype(np.float32) @ q_vec
    denom = corpus_norms.astype(np.float32) * np.linalg.norm(q_vec) + 1e-8
    sims = sims / denom
    top_idx = np.argsort(sims)[::-1][:top_k]
    return [(prompts[i], float(sims[i])) for i in top_idx]
# ===============================
# 🔟 Test
# ===============================
# Smoke test: run one Korean query and print the best matches with their scores.
# The string literals had been stored mis-encoded (UTF-8 bytes decoded with
# the wrong code page); they are restored to the intended Korean text so the
# tokenizer sees a real sentence.
query = "점심이나 저녁을 우리와 함께 먹을 건가요?"
results = search(query)
for p, s in results:
    print(f"Prompt: {p}\n유사도: {s:.3f}\n---")