import torch
import torch.nn.functional as F
from pymongo import MongoClient
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel

# MongoDB Atlas connection settings
client = MongoClient(
    "mongodb+srv://waseoke:rookies3@cluster0.ps7gq.mongodb.net/test?retryWrites=true&w=majority"
)
db = client["two_tower_model"]
train_dataset = db["train_dataset"]

# Load the KoBERT model and tokenizer.
# Note: monologg/kobert ships a SentencePiece-based tokenizer; if BertTokenizer
# cannot load it, the model card's KoBertTokenizer (or AutoTokenizer with
# trust_remote_code=True) may be needed instead.
tokenizer = BertTokenizer.from_pretrained("monologg/kobert")
model = BertModel.from_pretrained("monologg/kobert")
model.eval()  # the encoder is used for embedding only, so disable dropout
# Product embedding function
def embed_product_data(product):
    """Embed a product's name and description with KoBERT."""
    text = (
        product.get("product_name", "") + " " + product.get("product_description", "")
    )
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=128
    )
    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = model(**inputs)
    embedding = outputs.last_hidden_state.mean(dim=1).numpy().flatten()  # mean pooling
    return embedding
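
# A sketch of an optional refinement (not used above): mean(dim=1) averages every
# token position, padding included. A common alternative is to mask out [PAD]
# tokens with the tokenizer's attention_mask before averaging.
def embed_product_data_masked(product):
    """Variant of embed_product_data that excludes padding from the mean pooling."""
    text = (
        product.get("product_name", "") + " " + product.get("product_description", "")
    )
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=128
    )
    with torch.no_grad():
        outputs = model(**inputs)
    mask = inputs["attention_mask"].unsqueeze(-1)  # (1, seq_len, 1)
    summed = (outputs.last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1)  # guard against division by zero
    return (summed / counts).numpy().flatten()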
# PyTorch Dataset definition
class TripletDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        data = self.dataset[idx]
        anchor = torch.tensor(data["anchor_embedding"], dtype=torch.float32)
        positive = torch.tensor(data["positive_embedding"], dtype=torch.float32)
        negative = torch.tensor(data["negative_embedding"], dtype=torch.float32)
        return anchor, positive, negative
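
# Usage sketch (assumes the 768-dim KoBERT embeddings produced above): each
# DataLoader batch is an (anchor, positive, negative) tuple of float32 tensors.
# `embedded` stands for a list of dicts with anchor/positive/negative embeddings,
# as built in prepare_training_data below.
#
#   loader = DataLoader(TripletDataset(embedded), batch_size=16, shuffle=True)
#   anchor, positive, negative = next(iter(loader))
#   print(anchor.shape)  # torch.Size([16, 768])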
# Load the dataset from MongoDB and convert it to embeddings
def prepare_training_data(verbose=False):
    dataset = list(train_dataset.find())
    if not dataset:
        raise ValueError("No training data found in MongoDB.")

    # Build anchor, positive, and negative embeddings
    embedded_dataset = []
    for idx, entry in enumerate(dataset):
        try:
            # Embed the anchor, positive, and negative products
            anchor_embedding = embed_product_data(entry["anchor"]["product"])
            positive_embedding = embed_product_data(entry["positive"]["product"])
            negative_embedding = embed_product_data(entry["negative"]["product"])

            # Optionally print the embeddings for inspection
            if verbose:
                print(f"Sample {idx + 1}:")
                print(
                    f"Anchor Embedding: {anchor_embedding[:5]}... (shape: {anchor_embedding.shape})"
                )
                print(
                    f"Positive Embedding: {positive_embedding[:5]}... (shape: {positive_embedding.shape})"
                )
                print(
                    f"Negative Embedding: {negative_embedding[:5]}... (shape: {negative_embedding.shape})"
                )

            # Store the embedding results
            embedded_dataset.append(
                {
                    "anchor_embedding": anchor_embedding,
                    "positive_embedding": positive_embedding,
                    "negative_embedding": negative_embedding,
                }
            )
        except Exception as e:
            print(f"Error embedding data at sample {idx + 1}: {e}")

    return TripletDataset(embedded_dataset)
# Dataset validation helper
def validate_embeddings():
    """Build the dataset embeddings and print a slice of each one for inspection."""
    print("Validating embeddings...")
    triplet_dataset = prepare_training_data(verbose=True)
    print(f"Total samples: {len(triplet_dataset)}")
    return triplet_dataset
# Train the model with triplet loss
def train_triplet_model(
    product_model, train_loader, num_epochs=10, learning_rate=0.001, margin=0.05
):
    optimizer = Adam(product_model.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        product_model.train()
        total_loss = 0
        for anchor, positive, negative in train_loader:
            optimizer.zero_grad()

            # Forward pass
            anchor_vec = product_model(anchor)
            positive_vec = product_model(positive)
            negative_vec = product_model(negative)

            # Triplet loss: pull the positive closer than the negative by at least `margin`
            positive_distance = F.pairwise_distance(anchor_vec, positive_vec)
            negative_distance = F.pairwise_distance(anchor_vec, negative_vec)
            triplet_loss = torch.clamp(
                positive_distance - negative_distance + margin, min=0
            ).mean()

            # Backpropagation and optimization
            triplet_loss.backward()
            optimizer.step()
            total_loss += triplet_loss.item()
        print(
            f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(train_loader):.4f}"
        )
    return product_model
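
# Note: the clamp above implements the standard triplet objective,
# max(d(a, p) - d(a, n) + margin, 0), averaged over the batch. PyTorch ships
# the same computation as F.triplet_margin_loss, so the inner loss could
# equivalently be written as:
#
#   triplet_loss = F.triplet_margin_loss(
#       anchor_vec, positive_vec, negative_vec, margin=margin, p=2
#   )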
# Model training pipeline
def main():
    # Model initialization (example tower)
    product_model = torch.nn.Sequential(
        torch.nn.Linear(768, 256),  # 768: KoBERT embedding dimension
        torch.nn.ReLU(),
        torch.nn.Linear(256, 128),
    )

    # Prepare and validate the data (embeds the dataset once and prints a
    # slice of each embedding for inspection)
    triplet_dataset = validate_embeddings()
    train_loader = DataLoader(triplet_dataset, batch_size=16, shuffle=True)

    # Train the model
    trained_model = train_triplet_model(product_model, train_loader)

    # Save the trained model
    torch.save(trained_model.state_dict(), "product_model.pth")
    print("Model training completed and saved.")


if __name__ == "__main__":
    main()
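
# Inference sketch (assumes the tower architecture defined in main()): rebuild
# the same module stack, load the saved weights, and project a new product
# embedding. `some_product` is a hypothetical document with product_name and
# product_description fields.
#
#   product_model = torch.nn.Sequential(
#       torch.nn.Linear(768, 256), torch.nn.ReLU(), torch.nn.Linear(256, 128)
#   )
#   product_model.load_state_dict(torch.load("product_model.pth"))
#   product_model.eval()
#   vec = product_model(torch.tensor(embed_product_data(some_product)))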