|
|
import os
import multiprocessing

# Send the Hugging Face transformers cache to a writable scratch directory.
# NOTE(review): presumably this must be set before transformers is imported
# anywhere in the process — confirm if a transformers import is ever added.
os.environ.update({'TRANSFORMERS_CACHE': '/tmp/huggingface_cache'})

# Request the 'spawn' start method for child processes. set_start_method()
# raises RuntimeError when the start method has already been fixed (for
# example on a re-import); in that case the existing setting is kept.
try:
    multiprocessing.set_start_method('spawn')
except RuntimeError:
    pass
|
|
import json |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
from typing import List, Dict |
|
|
import logging |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
import random |
|
|
from collections import defaultdict |
|
|
from pathlib import Path |
|
|
from tqdm import tqdm |
|
|
import os |
|
|
import multiprocessing |
|
|
from multiprocessing import Pool |
|
|
import psutil |
|
|
import argparse |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Built-in defaults; each can be overridden by the environment variable of the
# same name without the DEFAULT_ prefix.
DEFAULT_NUM_TRIPLETS = 150
DEFAULT_NUM_EPOCHS = 1
DEFAULT_BATCH_SIZE = 64
DEFAULT_LEARNING_RATE = 0.001
DEFAULT_OUTPUT_DIM = 256
DEFAULT_MAX_SEQ_LENGTH = 15
DEFAULT_SAVE_INTERVAL = 2
DEFAULT_DATA_PATH = "./users.json"
DEFAULT_OUTPUT_DIR = "./model"


def _env_setting(name, default, cast=None):
    """Read environment variable ``name`` (falling back to ``default``).

    When ``cast`` is given, the raw value (string from the environment or the
    default) is passed through it; otherwise it is returned unchanged.
    """
    raw = os.environ.get(name, default)
    return raw if cast is None else cast(raw)


# Effective settings after applying environment overrides.
NUM_TRIPLETS = _env_setting("NUM_TRIPLETS", DEFAULT_NUM_TRIPLETS, int)
NUM_EPOCHS = _env_setting("NUM_EPOCHS", DEFAULT_NUM_EPOCHS, int)
BATCH_SIZE = _env_setting("BATCH_SIZE", DEFAULT_BATCH_SIZE, int)
LEARNING_RATE = _env_setting("LEARNING_RATE", DEFAULT_LEARNING_RATE, float)
OUTPUT_DIM = _env_setting("OUTPUT_DIM", DEFAULT_OUTPUT_DIM, int)
MAX_SEQ_LENGTH = _env_setting("MAX_SEQ_LENGTH", DEFAULT_MAX_SEQ_LENGTH, int)
SAVE_INTERVAL = _env_setting("SAVE_INTERVAL", DEFAULT_SAVE_INTERVAL, int)
DATA_PATH = _env_setting("DATA_PATH", DEFAULT_DATA_PATH)
OUTPUT_DIR = _env_setting("OUTPUT_DIR", DEFAULT_OUTPUT_DIR)

# Root logger: INFO level with timestamped records.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

# Compute device shared by the whole module: the GPU when CUDA is usable.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logging.info(f"Using device: {device}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UserEmbeddingModel(nn.Module):
    """Encode a user's categorical profile fields into one dense vector.

    One ``nn.Embedding`` (padding_idx 0) is created for every field in
    ``vocab_sizes``. In ``forward`` each field is pooled to one vector per
    user. The pooled vectors are concatenated in ``FIELD_ORDER`` and passed
    through Linear -> ReLU -> Dropout(0.2) -> Linear -> LayerNorm(output_dim).
    """

    # Concatenation order used by forward(); it also fixes the MLP input width.
    FIELD_ORDER = ('dmp_city', 'source', 'dmp_brands', 'dmp_clusters',
                   'dmp_industries', 'dmp_tags', 'dmp_channels', 'device')
    # Fields represented by their first index instead of a pooled mean.
    SINGLE_VALUE_FIELDS = ('dmp_city', 'dmp_domains')
    # Embedding width for fields that have no entry in ``embedding_dims``.
    DEFAULT_EMBEDDING_DIM = 16

    def __init__(self, vocab_sizes: Dict[str, int], embedding_dims: Dict[str, int],
                 output_dim: int = 256, max_sequence_length: int = 15,
                 padded_fields_length: int = 10):
        """Create the per-field embeddings and the projection head.

        Args:
            vocab_sizes: field name -> vocabulary size (index 0 is padding).
            embedding_dims: field name -> embedding width (default 16).
            output_dim: size of the final user embedding.
            max_sequence_length: stored on the instance; not used by forward().
            padded_fields_length: stored on the instance; not used by forward().
        """
        super().__init__()

        self.max_sequence_length = max_sequence_length
        self.padded_fields_length = padded_fields_length
        self.padded_fields = {'dmp_channels', 'dmp_tags', 'dmp_clusters'}
        self.embedding_layers = nn.ModuleDict()

        for field, vocab_size in vocab_sizes.items():
            self.embedding_layers[field] = nn.Embedding(
                vocab_size,
                embedding_dims.get(field, self.DEFAULT_EMBEDDING_DIM),
                padding_idx=0,
            )

        # Size the MLP input from the embeddings forward() actually
        # concatenates. Summing ``embedding_dims`` directly mis-sizes the first
        # Linear whenever it disagrees with ``vocab_sizes`` / FIELD_ORDER.
        self.total_input_dim = sum(
            self.embedding_layers[field].embedding_dim
            for field in self.FIELD_ORDER
            if field in self.embedding_layers
        )
        logging.info(f"Total input dimension: {self.total_input_dim}")

        self.fc = nn.Sequential(
            nn.Linear(self.total_input_dim, self.total_input_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(self.total_input_dim // 2, output_dim),
            nn.LayerNorm(output_dim),
        )

    def _process_sequence(self, embedding_layer: nn.Embedding, indices: torch.Tensor,
                          field_name: str) -> torch.Tensor:
        """Pool a non-padded field into one vector per row.

        Single-valued fields (``SINGLE_VALUE_FIELDS``) use the embedding of the
        first index. Other fields use the mean over non-padding positions, so
        the result does not depend on how much padding the collate step added.
        """
        batch_size = indices.size(0)
        if indices.numel() == 0:
            return torch.zeros(batch_size, embedding_layer.embedding_dim, device=indices.device)

        if field_name in self.SINGLE_VALUE_FIELDS:
            if indices.dim() == 1:
                indices = indices.unsqueeze(0)
            if indices.size(1) > 0:
                return embedding_layer(indices[:, 0])
            return torch.zeros(batch_size, embedding_layer.embedding_dim, device=indices.device)

        # A plain mean(dim=1) would divide by the padded length, so a user's
        # embedding would change with the longest sequence in its batch.
        return self._process_padded_sequence(embedding_layer, indices)

    def _process_padded_sequence(self, embedding_layer: nn.Embedding,
                                 indices: torch.Tensor) -> torch.Tensor:
        """Mean of the embeddings at non-zero (non-padding) positions.

        ``indices`` is (batch, seq_len). Rows that are entirely padding
        produce a zero vector (the count is clamped to 1).
        """
        embeddings = embedding_layer(indices)
        mask = (indices != 0).float().unsqueeze(-1)
        summed = (embeddings * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1.0)
        return summed / counts

    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Embed a batch of users.

        Args:
            inputs: field name -> (batch, seq_len) index tensor, as produced
                by ``collate_batch``.

        Returns:
            Tensor of shape (batch, output_dim).
        """
        batch_size = next(iter(inputs.values())).size(0) if inputs else 0
        pooled = []
        for field in self.FIELD_ORDER:
            if field not in self.embedding_layers:
                continue
            layer = self.embedding_layers[field]
            if field not in inputs:
                # Keep the concatenated width equal to total_input_dim; a missing
                # field carries no information, like an empty one.
                pooled.append(torch.zeros(batch_size, layer.embedding_dim,
                                          device=layer.weight.device))
                continue
            if field in self.padded_fields:
                pooled.append(self._process_padded_sequence(layer, inputs[field]))
            else:
                pooled.append(self._process_sequence(layer, inputs[field], field))

        combined = torch.cat(pooled, dim=1)
        return self.fc(combined)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UserEmbeddingPipeline:
    """End-to-end helper around ``UserEmbeddingModel``.

    Builds per-field vocabularies from user records, converts users into index
    tensors, generates embeddings in batches, and saves/loads the model and
    its artefacts. User records may be bare dicts or wrapped as
    ``{'user': {...}}`` or ``{'raw_json': {'user': {...}}}``.
    """

    # Reserved token, always at vocabulary index 0. That index is the
    # embedding padding_idx and is also what unknown values map to.
    PAD_TOKEN = '<PAD>'

    # Multi-valued fields that _get_field_from_user always returns as lists.
    LIST_FIELDS = frozenset({'dmp_brands', 'dmp_channels', 'dmp_clusters', 'dmp_industries', 'dmp_tags'})

    def __init__(self, output_dim: int = 256, max_sequence_length: int = 15):
        self.output_dim = output_dim
        self.max_sequence_length = max_sequence_length
        self.model = None
        # field -> {token: index}; filled by build_vocabularies() or load_model().
        self.vocab_maps = {}

        # Fields fed to the model.
        self.fields = [
            'dmp_city', 'source', 'dmp_brands',
            'dmp_clusters', 'dmp_industries', 'dmp_tags', 'dmp_channels',
            'device'
        ]

        # Key path inside a user record for each field.
        self.field_mapping = {
            'dmp_city': ('dmp', 'city'),
            'source': ('dmp', '', 'source'),
            'dmp_brands': ('dmp', 'brands'),
            'dmp_clusters': ('dmp', 'clusters'),
            'dmp_industries': ('dmp', 'industries'),
            'dmp_tags': ('dmp', 'tags'),
            'dmp_channels': ('dmp', 'channels'),
            'device': ('device',)
        }

        # Embedding width per field.
        self.embedding_dims = {
            'dmp_city': 8,
            'source': 8,
            'dmp_brands': 32,
            'dmp_clusters': 64,
            'dmp_industries': 32,
            'dmp_tags': 128,
            'dmp_channels': 64,
            'device': 8
        }

    @staticmethod
    def _unwrap_user(data):
        """Return the user dict from a raw, ``{'user': ...}`` or ``{'raw_json': {'user': ...}}`` record."""
        if not isinstance(data, dict):
            return data
        raw = data.get('raw_json')
        if isinstance(raw, dict) and 'user' in raw:
            return raw['user']
        if 'user' in data:
            return data['user']
        return data

    @staticmethod
    def _extract_user_id(user):
        """Return the user's id as a string, or None when no id is present.

        Prefers ``user['dmp']['']['id']``, then ``user['uid']``, then ``user['id']``.
        """
        if not isinstance(user, dict):
            return None
        dmp = user.get('dmp')
        meta = dmp.get('') if isinstance(dmp, dict) else None
        if isinstance(meta, dict) and meta.get('id') is not None:
            return str(meta['id'])
        fallback = user.get('uid', user.get('id'))
        return None if fallback is None else str(fallback)

    def _clean_value(self, value):
        """Normalise a raw field value into a list of lowercase, stripped tokens.

        NaN, None, dicts and other unsupported types give ``[]``. Strings and
        numbers become one-element lists. List elements that are None or blank
        are dropped, and so is a blank string.
        """
        if isinstance(value, float) and np.isnan(value):
            return []
        if isinstance(value, str):
            token = value.lower().strip()
            return [token] if token else []
        if isinstance(value, list):
            return [str(v).lower().strip() for v in value if v is not None and str(v).strip()]
        if isinstance(value, (int, float)):
            # Numeric scalars are kept, stringified the same way list items are
            # (bool is an int subclass and becomes 'true'/'false').
            return [str(value).lower()]
        return []

    def _get_field_from_user(self, user, field):
        """Extract a raw field value by following ``field_mapping``.

        A missing key yields ``{}``, and a non-dict encountered on the path
        yields ``[]``. Fields in ``LIST_FIELDS`` are always returned as lists:
        a non-empty scalar is wrapped, and dicts or falsy values become ``[]``.
        """
        value = user
        for key in self.field_mapping.get(field, (field,)):
            if not isinstance(value, dict):
                value = []
                break
            value = value.get(key, {})

        if field in self.LIST_FIELDS and not isinstance(value, list):
            value = [value] if value and not isinstance(value, dict) else []
        return value

    def build_vocabularies(self, users_data: List[Dict]) -> Dict[str, Dict[str, int]]:
        """Build ``self.vocab_maps`` (field -> token -> index) from ``users_data``.

        ``PAD_TOKEN`` is always index 0, because the model uses index 0 as
        padding_idx. The remaining tokens follow in sorted order. (Sorting
        '<PAD>' together with the data is not enough, since digits and some
        punctuation sort before '<'.)
        """
        field_values = {field: set() for field in self.fields}
        for data in users_data:
            user = self._unwrap_user(data)
            for field in self.fields:
                field_values[field].update(self._clean_value(self._get_field_from_user(user, field)))

        self.vocab_maps = {}
        for field, values in field_values.items():
            values.discard(self.PAD_TOKEN)
            ordered = [self.PAD_TOKEN] + sorted(values)
            self.vocab_maps[field] = {token: idx for idx, token in enumerate(ordered)}
        return self.vocab_maps

    def _prepare_input(self, user: Dict) -> Dict[str, torch.Tensor]:
        """Convert one (unwrapped) user into field -> 1-D long tensor of vocab indices.

        Values that are not in the vocabulary map to index 0 (padding).
        """
        inputs = {}
        for field in self.fields:
            values = self._clean_value(self._get_field_from_user(user, field))
            vocab = self.vocab_maps[field]
            inputs[field] = torch.tensor([vocab.get(val, 0) for val in values], dtype=torch.long)
        return inputs

    def initialize_model(self) -> None:
        """Create a fresh ``UserEmbeddingModel`` sized to the current vocabularies (eval mode, on ``device``)."""
        vocab_sizes = {field: len(vocab) for field, vocab in self.vocab_maps.items()}
        self.model = UserEmbeddingModel(
            vocab_sizes=vocab_sizes,
            embedding_dims=self.embedding_dims,
            output_dim=self.output_dim,
            max_sequence_length=self.max_sequence_length
        )
        self.model.to(device)
        self.model.eval()

    def generate_embeddings(self, users_data: List[Dict], batch_size: int = 32) -> Dict[str, np.ndarray]:
        """Generate embeddings for all users that have an id.

        Returns:
            dict mapping user id (str) to a numpy vector of size ``output_dim``.
            Users without an id are skipped (with a warning) instead of being
            stored under the string key ``"None"`` and overwriting each other.

        Raises:
            RuntimeError: if no model has been initialised or loaded.
        """
        if self.model is None:
            raise RuntimeError("Model is not initialised; call initialize_model() or load_model() first")

        embeddings = {}
        self.model.eval()

        users = [self._unwrap_user(data) for data in users_data]
        keyed = []
        for user in users:
            user_id = self._extract_user_id(user)
            if user_id:
                keyed.append((user_id, user))
        skipped = len(users) - len(keyed)
        if skipped:
            logging.warning(f"Skipping {skipped} users without an id")

        with torch.no_grad():
            for start in tqdm(range(0, len(keyed), batch_size), desc="Generating embeddings"):
                chunk = keyed[start:start + batch_size]
                batch_inputs = [self._prepare_input(user) for _, user in chunk]
                # collate_batch expects triplets; only the anchor group is used here.
                anchor_batch, _, _ = collate_batch([(inputs, inputs, inputs) for inputs in batch_inputs])
                anchor_batch = {k: v.to(device) for k, v in anchor_batch.items()}
                batch_embeddings = self.model(anchor_batch).cpu()
                for row, (user_id, _) in enumerate(chunk):
                    embeddings[user_id] = batch_embeddings[row].numpy()

        return embeddings

    def save_embeddings(self, embeddings: Dict[str, np.ndarray], output_dir: str) -> None:
        """Write embeddings (JSON and compressed NPZ) and vocabularies to ``output_dir``.

        ``output_dir`` and any missing parents are created. An empty mapping
        produces an NPZ with a (0, output_dim) matrix instead of failing.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        json_path = output_dir / 'embeddings.json'
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump({user_id: emb.tolist() for user_id, emb in embeddings.items()}, f)

        npy_path = output_dir / 'embeddings.npz'
        if embeddings:
            matrix = np.stack(list(embeddings.values()))
        else:
            # np.stack([]) raises; keep the file layout consistent instead.
            matrix = np.empty((0, self.output_dim), dtype=np.float32)
        np.savez_compressed(npy_path,
                            embeddings=matrix,
                            user_ids=np.array(list(embeddings.keys()), dtype=str))

        vocab_path = output_dir / 'vocabularies.json'
        with open(vocab_path, 'w', encoding='utf-8') as f:
            json.dump(self.vocab_maps, f)

        logging.info(f"\nEmbeddings saved in {output_dir}:")
        logging.info(f"- Embeddings JSON: {json_path}")
        logging.info(f"- Embeddings NPZ: {npy_path}")
        logging.info(f"- Vocabularies: {vocab_path}")

    def save_model(self, output_dir: str) -> None:
        """Save the model checkpoint (.pth), a JSON config, and a HuggingFace-style copy.

        Raises:
            RuntimeError: if there is no model to save.
        """
        if self.model is None:
            raise RuntimeError("No model to save; call initialize_model() or load_model() first")

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        model_path = output_dir / 'model.pth'
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'vocab_maps': self.vocab_maps,
            'embedding_dims': self.embedding_dims,
            'output_dim': self.output_dim,
            'max_sequence_length': self.max_sequence_length
        }
        torch.save(checkpoint, model_path)
        logging.info(f"Model saved to: {model_path}")

        config_info = {
            'model_type': 'UserEmbeddingModel',
            'vocab_sizes': {field: len(vocab) for field, vocab in self.vocab_maps.items()},
            'embedding_dims': self.embedding_dims,
            'output_dim': self.output_dim,
            'max_sequence_length': self.max_sequence_length,
            # Sorted so the written config is deterministic (sets are unordered).
            'padded_fields': sorted(self.model.padded_fields),
            'fields': self.fields
        }

        config_path = output_dir / 'model_config.json'
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config_info, f, indent=2)
        logging.info(f"Model configuration saved to: {config_path}")

        hf_dir = output_dir / 'huggingface'
        hf_dir.mkdir(exist_ok=True)
        torch.save(self.model.state_dict(), hf_dir / 'pytorch_model.bin')
        with open(hf_dir / 'config.json', 'w', encoding='utf-8') as f:
            json.dump(config_info, f, indent=2)
        logging.info(f"Model saved in HuggingFace format to: {hf_dir}")

    def load_model(self, model_path: str) -> None:
        """Load a checkpoint written by ``save_model`` (or a training checkpoint).

        Configuration stored in the checkpoint overrides the current settings.
        Settings missing from it keep their current values. The model is
        always rebuilt from the resulting configuration, because an existing
        model may have been built for different vocabulary sizes.
        """
        checkpoint = torch.load(model_path, map_location=device)

        self.vocab_maps = checkpoint.get('vocab_maps', self.vocab_maps)
        self.embedding_dims = checkpoint.get('embedding_dims', self.embedding_dims)
        self.output_dim = checkpoint.get('output_dim', self.output_dim)
        self.max_sequence_length = checkpoint.get('max_sequence_length', self.max_sequence_length)

        self.initialize_model()
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(device)
        self.model.eval()

        logging.info(f"Model loaded from: {model_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def calculate_similarity(user1, user2, pipeline, filtered_tags=None):
    """Weighted Jaccard similarity between two users' categorical fields.

    Each field is reduced to a set of normalised tokens via
    ``pipeline._clean_value(pipeline._get_field_from_user(user, field))``.
    This wraps scalars such as a string ``source`` in a list (a bare string
    would otherwise be iterated character by character) and lowercases and
    strips every token. Per-field Jaccard scores are combined with weights
    clusters 6, channels 5, brands 5, tags 3, device 3, source 2.

    Args:
        user1, user2: unwrapped user records.
        pipeline: object providing ``_get_field_from_user`` and ``_clean_value``.
        filtered_tags: optional container of allowed tags; other tags are
            ignored when comparing ``dmp_tags``.

    Returns:
        float in [0, 1]; 0.0 if extraction fails (the error is logged).
    """
    weights = {
        'dmp_clusters': 6,
        'dmp_channels': 5,
        'dmp_tags': 3,
        'source': 2,
        'dmp_brands': 5,
        'device': 3,
    }

    def tokens(user, field):
        values = set(pipeline._clean_value(pipeline._get_field_from_user(user, field)))
        if field == 'dmp_tags' and filtered_tags is not None:
            values = {tag for tag in values if tag in filtered_tags}
        return values

    def jaccard(a, b):
        # max(1, ...) makes two empty sets score 0 instead of dividing by zero.
        return len(a & b) / max(1, len(a | b))

    try:
        weighted_sum = sum(
            weight * jaccard(tokens(user1, field), tokens(user2, field))
            for field, weight in weights.items()
        )
        return weighted_sum / sum(weights.values())
    except Exception as e:
        # Best effort: a malformed record must not abort triplet mining.
        logging.error(f"Error calculating similarity: {str(e)}")
        return 0.0
|
|
|
|
|
|
|
|
def process_batch_triplets(args):
    """Sample ``num_triplets`` (anchor, positive, negative) user-index triplets.

    Takes a single tuple argument, which suits pool-style mapping. It is not
    called in this part of the file.

    Args:
        args: tuple ``(batch_idx, users, channel_index, cluster_index,
            num_triplets, pipeline)``. ``batch_idx`` is unused.
            ``channel_index`` / ``cluster_index`` map ``str(value)`` to lists
            of user indices.

    Returns:
        List of ``(anchor_idx, positive_idx, negative_idx)`` tuples. The
        positive and the negative differ from the anchor whenever the user
        list allows it. Returns ``[]`` on error (the error is logged).
    """
    try:
        _batch_idx, users, channel_index, cluster_index, num_triplets, pipeline = args
        n_users = len(users)
        if n_users == 0:
            return []

        def random_index(exclude):
            # Uniform random index outside ``exclude`` whenever one exists.
            if n_users <= len(exclude):
                return random.randrange(n_users)
            while True:
                idx = random.randrange(n_users)
                if idx not in exclude:
                    return idx

        batch_triplets = []
        for _ in range(num_triplets):
            anchor_idx = random.randrange(n_users)
            anchor_user = users[anchor_idx]

            # Users sharing at least one channel or cluster with the anchor.
            candidates = set()
            for channel in pipeline._get_field_from_user(anchor_user, 'dmp_channels'):
                candidates.update(channel_index.get(str(channel), []))
            for cluster in pipeline._get_field_from_user(anchor_user, 'dmp_clusters'):
                candidates.update(cluster_index.get(str(cluster), []))
            candidates.discard(anchor_idx)

            # Positive: uniformly one of the 10 most similar candidates.
            similarities = []
            for idx in candidates:
                sim = calculate_similarity(anchor_user, users[idx], pipeline)
                if sim > 0:
                    similarities.append((idx, sim))
            if similarities:
                similarities.sort(key=lambda pair: pair[1], reverse=True)
                top_k = min(10, len(similarities))
                positive_idx = similarities[random.randrange(top_k)][0]
            else:
                positive_idx = random_index({anchor_idx})

            # Negative: a random user clearly dissimilar to the anchor.
            negative_idx = None
            for _ in range(50):
                idx = random.randrange(n_users)
                if idx in (anchor_idx, positive_idx):
                    continue
                if calculate_similarity(anchor_user, users[idx], pipeline) < 0.1:
                    negative_idx = idx
                    break
            if negative_idx is None:
                negative_idx = random_index({anchor_idx, positive_idx})

            batch_triplets.append((anchor_idx, positive_idx, negative_idx))

        return batch_triplets

    except Exception as e:
        logging.error(f"Error in batch triplet generation: {str(e)}")
        return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UserSimilarityDataset(Dataset):
    """Triplet dataset (anchor, positive, negative) for metric learning on users.

    During construction:

    - every user is converted to model inputs with ``pipeline._prepare_input``;
    - the users are indexed by the normalised values of six categorical fields;
    - ``num_triplets`` triplets are sampled.

    Positives come from the most similar users that share at least one field
    value with the anchor. Negatives are random users whose similarity to the
    anchor is below ``NEGATIVE_SIM_THRESHOLD``. ``__getitem__`` returns the
    three preprocessed input dicts of a triplet.
    """

    # Candidates scored with calculate_similarity when picking a positive.
    MAX_POSITIVE_CANDIDATES = 50
    # The positive is drawn uniformly from this many best-scoring candidates.
    TOP_K_POSITIVES = 10
    # Random draws tried when searching for a dissimilar negative.
    NEGATIVE_ATTEMPTS = 20
    # Users with similarity strictly below this qualify as negatives.
    NEGATIVE_SIM_THRESHOLD = 0.1
    # Triplets sampled per progress-bar step.
    TRIPLET_BATCH_SIZE = 10

    def __init__(self, pipeline, users_data, num_triplets=10, num_workers=None, filtered_tags=None):
        """Preprocess users, build the inverted indexes and sample the triplets.

        Args:
            pipeline: object providing ``_prepare_input``, ``_get_field_from_user``
                and ``_clean_value`` (the embedding pipeline).
            users_data: user records, bare or wrapped as ``{'user': ...}`` /
                ``{'raw_json': {'user': ...}}``.
            num_triplets: number of triplets to sample.
            num_workers: stored for API compatibility; sampling runs in this process.
            filtered_tags: optional set of normalised (lowercase) tags. When
                given, only these tags are indexed and compared.
        """
        self.triplets = []
        self.filtered_tags = filtered_tags
        logging.info("Initializing UserSimilarityDataset...")

        self.users = [self._unwrap_user(data) for data in users_data]
        self.pipeline = pipeline
        self.num_triplets = num_triplets

        if num_workers is None:
            # os.cpu_count() may return None.
            num_workers = max(1, min(8, os.cpu_count() or 1))
        self.num_workers = num_workers

        self.preprocessed_inputs = {idx: pipeline._prepare_input(user) for idx, user in enumerate(self.users)}

        logging.info("Creating indexes for channels, clusters, tags, brands, source, and device...")
        self.channel_index = defaultdict(list)
        self.cluster_index = defaultdict(list)
        self.tag_index = defaultdict(list)
        self.brands_index = defaultdict(list)
        self.source_index = defaultdict(list)
        self.device_index = defaultdict(list)
        # field -> inverted index (normalised value -> user indices).
        self._field_indexes = {
            'dmp_channels': self.channel_index,
            'dmp_clusters': self.cluster_index,
            'dmp_tags': self.tag_index,
            'dmp_brands': self.brands_index,
            'source': self.source_index,
            'device': self.device_index,
        }

        # Normalised tokens per user and field, reused for candidate lookups.
        self._user_tokens = [self._tokens_for(user) for user in self.users]
        for idx, tokens in enumerate(self._user_tokens):
            for field, index in self._field_indexes.items():
                for token in tokens[field]:
                    index[token].append(idx)

        logging.info(f"Found {len(self.channel_index)} unique channels, {len(self.cluster_index)} unique clusters, {len(self.tag_index)} unique tags")
        logging.info(f"Found {len(self.brands_index)} unique brands, {len(self.source_index)} unique sources, and {len(self.device_index)} unique devices")

        logging.info(f"Generating triplets using {self.num_workers} worker processes...")
        self.triplets = self._generate_triplets_gpu(num_triplets)
        logging.info(f"Generated {len(self.triplets)} triplets")

    @staticmethod
    def _unwrap_user(data):
        """Return the user dict from a raw, ``{'user': ...}`` or ``{'raw_json': {'user': ...}}`` record."""
        if not isinstance(data, dict):
            return data
        raw = data.get('raw_json')
        if isinstance(raw, dict) and 'user' in raw:
            return raw['user']
        if 'user' in data:
            return data['user']
        return data

    def _tokens_for(self, user):
        """Return field -> de-duplicated list of normalised tokens for one user.

        ``_clean_value`` lowercases and strips tokens, matching the keys
        produced by ``compute_tag_frequencies``. It also wraps scalar values,
        so a string ``source`` is not split into characters.
        """
        tokens = {}
        for field in self._field_indexes:
            values = self.pipeline._clean_value(self.pipeline._get_field_from_user(user, field))
            if field == 'dmp_tags' and self.filtered_tags is not None:
                values = [tag for tag in values if tag in self.filtered_tags]
            tokens[field] = list(dict.fromkeys(values))
        return tokens

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        if idx >= len(self.triplets):
            raise IndexError(f"Index {idx} out of range for dataset with {len(self.triplets)} triplets")

        anchor_idx, positive_idx, negative_idx = self.triplets[idx]
        return (
            self.preprocessed_inputs[anchor_idx],
            self.preprocessed_inputs[positive_idx],
            self.preprocessed_inputs[negative_idx]
        )

    def _generate_triplets_gpu(self, num_triplets):
        """Sample ``num_triplets`` (anchor, positive, negative) index triplets.

        Despite the name, this runs on the CPU in the current process. It
        returns ``[]`` when ``num_triplets`` is not positive or there are no users.
        """
        logging.info("Generating triplets with batch approach...")
        if num_triplets <= 0:
            return []
        if not self.users:
            logging.warning("No users available; no triplets generated")
            return []

        triplets = []
        num_batches = (num_triplets + self.TRIPLET_BATCH_SIZE - 1) // self.TRIPLET_BATCH_SIZE
        progress_bar = tqdm(
            range(num_batches),
            desc="Generating triplet batches",
            bar_format='{l_bar}{bar:10}{r_bar}'
        )
        for _ in progress_bar:
            for _ in range(min(self.TRIPLET_BATCH_SIZE, num_triplets - len(triplets))):
                triplets.append(self._sample_triplet())
        return triplets

    def _sample_triplet(self):
        """Draw one (anchor_idx, positive_idx, negative_idx) triplet."""
        anchor_idx = random.randrange(len(self.users))
        anchor_user = self.users[anchor_idx]
        positive_idx = self._pick_positive(anchor_idx, anchor_user)
        negative_idx = self._pick_negative(anchor_idx, positive_idx, anchor_user)
        return anchor_idx, positive_idx, negative_idx

    def _random_index(self, exclude):
        """Uniform random user index outside ``exclude`` whenever one exists."""
        n_users = len(self.users)
        if n_users <= len(exclude):
            return random.randrange(n_users)
        while True:
            idx = random.randrange(n_users)
            if idx not in exclude:
                return idx

    def _pick_positive(self, anchor_idx, anchor_user):
        """Pick a positive from the best-scoring users sharing a field value with the anchor.

        If no candidate is similar, a random user other than the anchor is returned.
        """
        candidates = set()
        for field, index in self._field_indexes.items():
            for token in self._user_tokens[anchor_idx][field]:
                candidates.update(index.get(token, ()))
        candidates.discard(anchor_idx)

        pool = sorted(candidates)
        if len(pool) > self.MAX_POSITIVE_CANDIDATES:
            # A random subset avoids always scoring the same low-index users.
            pool = random.sample(pool, self.MAX_POSITIVE_CANDIDATES)

        scored = []
        for idx in pool:
            sim = calculate_similarity(anchor_user, self.users[idx], self.pipeline, self.filtered_tags)
            if sim > 0:
                scored.append((idx, sim))
        if not scored:
            return self._random_index({anchor_idx})

        scored.sort(key=lambda pair: pair[1], reverse=True)
        top_k = min(self.TOP_K_POSITIVES, len(scored))
        return scored[random.randrange(top_k)][0]

    def _pick_negative(self, anchor_idx, positive_idx, anchor_user):
        """Pick a random user dissimilar to the anchor, avoiding the anchor and the positive."""
        for _ in range(self.NEGATIVE_ATTEMPTS):
            idx = random.randrange(len(self.users))
            if idx in (anchor_idx, positive_idx):
                continue
            sim = calculate_similarity(anchor_user, self.users[idx], self.pipeline, self.filtered_tags)
            if sim < self.NEGATIVE_SIM_THRESHOLD:
                return idx
        return self._random_index({anchor_idx, positive_idx})
|
|
|
|
|
|
|
|
def collate_batch(batch):
    """Turn a list of (anchor, positive, negative) input dicts into three padded batches.

    Each item of ``batch`` is a triple of dicts mapping field name to a 1-D
    index tensor. Within each role (anchor / positive / negative), every
    field is right-padded with zeros to that role's longest sequence and
    stacked into a ``(len(batch), max_len)`` tensor.
    """
    anchors, positives, negatives = zip(*batch)

    def _pad_role(members):
        padded = {}
        for name in members[0].keys():
            reference = members[0][name]
            width = max(member[name].size(0) for member in members)
            out = torch.zeros(len(members), width, dtype=reference.dtype, device=reference.device)
            for row, member in enumerate(members):
                seq = member[name]
                out[row, :seq.size(0)] = seq
            padded[name] = out
        return padded

    return _pad_role(anchors), _pad_role(positives), _pad_role(negatives)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _open_tensorboard_writer(save_dir):
    """Create a TensorBoard ``SummaryWriter`` logging to ``<save_dir>/logs``.

    Falls back to ``./logs`` when ``save_dir`` is falsy. Returns None, after
    logging a warning, when TensorBoard cannot be imported.
    """
    try:
        from torch.utils.tensorboard import SummaryWriter
    except ImportError:
        logging.warning("TensorBoard not available, skipping logging")
        return None
    log_dir = Path(save_dir) / "logs" if save_dir else Path("./logs")
    log_dir.mkdir(exist_ok=True, parents=True)
    return SummaryWriter(log_dir=log_dir)


def train_user_embeddings(model, users_data, pipeline, num_epochs=10, batch_size=32, lr=0.001, save_dir=None, save_interval=2, num_triplets=150):
    """Train ``model`` with a triplet-margin loss on triplets mined from ``users_data``.

    Args:
        model: the ``UserEmbeddingModel`` to train, in place.
        users_data: user records passed to ``UserSimilarityDataset``.
        pipeline: embedding pipeline used for preprocessing and similarity.
        num_epochs: number of passes over the sampled triplets.
        batch_size: triplets per optimisation step.
        lr: Adam learning rate. It is multiplied by 0.9 every 2 epochs.
        save_dir: if set, checkpoints are written here every ``save_interval``
            epochs (the directory is created if needed).
        save_interval: epochs between checkpoints; values below 1 mean every epoch.
        num_triplets: number of triplets to sample.

    Returns:
        The same ``model`` object, left in training mode.
    """
    model.train()
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)

    num_cpu_cores = max(1, min(32, os.cpu_count() or 1))
    logging.info(f"Using {num_cpu_cores} CPU cores for data processing")

    dataset = UserSimilarityDataset(
        pipeline,
        users_data,
        num_triplets=num_triplets,
        num_workers=num_cpu_cores
    )
    if len(dataset) == 0:
        # DataLoader(shuffle=True) rejects an empty dataset; nothing to learn from anyway.
        logging.warning("No training triplets available; skipping training")
        return model

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_batch,
        num_workers=0,
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=(device.type == 'cuda')
    )
    criterion = torch.nn.TripletMarginLoss(margin=1.0)

    checkpoint_dir = None
    if save_dir:
        checkpoint_dir = Path(save_dir)
        # Previously the directory only existed if TensorBoard happened to create it.
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        save_interval = max(1, save_interval)

    writer = _open_tensorboard_writer(save_dir)
    epoch_pbar = tqdm(
        range(num_epochs),
        desc="Training Progress",
        bar_format='{l_bar}{bar:10}{r_bar}'
    )

    try:
        for epoch in epoch_pbar:
            total_loss = 0.0
            num_batches = 0
            total_batches = len(dataloader)
            update_interval = max(1, total_batches // 10)

            with tqdm(total=total_batches, desc=f"Epoch {epoch+1}/{num_epochs}",
                      leave=True, bar_format='{l_bar}{bar:10}{r_bar}') as epoch_progress:
                for batch_idx, batch_inputs in enumerate(dataloader):
                    epoch_progress.update(1)
                    try:
                        anchor_batch, positive_batch, negative_batch = (
                            {k: v.to(device) for k, v in group.items()} for group in batch_inputs
                        )
                        loss = criterion(model(anchor_batch), model(positive_batch), model(negative_batch))

                        optimizer.zero_grad()
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                        optimizer.step()
                    except Exception as e:
                        # Best effort: skip the failing batch but keep the traceback.
                        logging.exception(f"Error during batch processing: {str(e)}")
                        logging.error(f"Batch details: {str(e.__class__.__name__)}")
                        continue

                    batch_loss = loss.item()
                    total_loss += batch_loss
                    num_batches += 1

                    if (batch_idx + 1) % update_interval == 0 or batch_idx == total_batches - 1:
                        epoch_progress.set_postfix(avg_loss=f"{total_loss / num_batches:.4f}",
                                                   last_batch_loss=f"{batch_loss:.4f}")

            if total_batches and num_batches == 0:
                logging.error(f"All {total_batches} batches failed in epoch {epoch+1}")

            avg_loss = total_loss / max(1, num_batches)
            epoch_pbar.set_postfix(avg_loss=f"{avg_loss:.4f}")

            if writer is not None:
                writer.add_scalar('Loss/train', avg_loss, epoch)

            scheduler.step()

            if checkpoint_dir is not None and (epoch + 1) % save_interval == 0:
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': avg_loss,
                    'scheduler_state_dict': scheduler.state_dict()
                }
                save_path = checkpoint_dir / f'model_checkpoint_epoch_{epoch+1}.pth'
                torch.save(checkpoint, save_path)
                logging.info(f"Checkpoint saved at epoch {epoch+1}: {save_path}")
    finally:
        if writer is not None:
            writer.close()

    return model
|
|
|
|
|
|
|
|
def compute_tag_frequencies(pipeline, users_data):
    """Count how often each tag appears across the dataset.

    Each tag is normalised with ``str(tag).lower().strip()``. ``None`` tags,
    and tags that are empty after stripping, are ignored; such tags carry no
    information and would otherwise be counted under the key ``''``.

    Args:
        pipeline: object providing ``_get_field_from_user(user, field)``
            (the embedding pipeline).
        users_data: user records, bare or wrapped as ``{'user': ...}`` /
            ``{'raw_json': {'user': ...}}``.

    Returns:
        tuple: (dict mapping tag to its number of occurrences,
        total number of tag occurrences counted).
    """
    from collections import Counter

    logging.info("Calcolando frequenze dei tag...")
    counts = Counter()

    for data in users_data:
        raw = data.get('raw_json') if isinstance(data, dict) else None
        if isinstance(raw, dict) and 'user' in raw:
            user = raw['user']
        elif isinstance(data, dict) and 'user' in data:
            user = data['user']
        else:
            user = data

        for tag in pipeline._get_field_from_user(user, 'dmp_tags'):
            if tag is None:
                continue
            tag_str = str(tag).lower().strip()
            if tag_str:
                counts[tag_str] += 1

    total_tags = sum(counts.values())
    logging.info(f"Trovati {len(counts)} tag unici su {total_tags} occorrenze totali")
    return dict(counts), total_tags
|
|
|
|
|
|
|
|
|
|
|
def filter_tags_by_criteria(tag_frequencies, min_frequency=100, percentile=None):
    """
    Select tags either by absolute frequency or by frequency rank.

    Args:
        tag_frequencies: Dict mapping tag -> occurrence count.
        min_frequency: Keep tags whose count is >= this value. Only used when
            ``percentile`` is None. Default: 100.
        percentile: If given, keep the most frequent ``percentile`` percent
            of the *distinct* tags (e.g. 10 keeps the top 10% by count).
            Values outside [0, 100] are clamped. Tags tied at the cutoff
            count are kept or dropped according to their order in
            ``tag_frequencies``, because the sort is stable.

    Returns:
        set: Tags that satisfy the selected criterion.
    """
    if percentile is None:
        filtered_tags = {tag for tag, freq in tag_frequencies.items() if freq >= min_frequency}
        logging.info(f"Filtrati tag con frequenza < {min_frequency}. Mantenuti {len(filtered_tags)} tag")
        return filtered_tags

    sorted_tags = sorted(tag_frequencies.items(), key=lambda item: item[1], reverse=True)

    # Clamp so that percentile > 100 cannot index past the end, and a
    # negative percentile cannot turn into a from-the-end slice.
    cutoff_index = int(len(sorted_tags) * (percentile / 100.0))
    cutoff_index = max(0, min(cutoff_index, len(sorted_tags)))

    filtered_tags = {tag for tag, _ in sorted_tags[:cutoff_index]}

    # Lowest count that made it into the kept set (0 when nothing is kept).
    min_freq_in_set = sorted_tags[cutoff_index - 1][1] if cutoff_index > 0 else 0
    logging.info(f"Filtrati tag al {percentile}° percentile. Mantenuti {len(filtered_tags)} tag con frequenza >= {min_freq_in_set}")

    return filtered_tags
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _parse_json_value_stream(text):
    """Decode consecutive JSON values separated by whitespace and/or commas.

    List values found in the stream are flattened into the result. Raises
    json.JSONDecodeError if any value is malformed.
    """
    decoder = json.JSONDecoder()
    records = []
    pos = 0
    end = len(text)
    while True:
        # raw_decode does not skip leading whitespace, so advance past
        # separators manually.
        while pos < end and (text[pos].isspace() or text[pos] == ','):
            pos += 1
        if pos >= end:
            return records
        value, pos = decoder.raw_decode(text, pos)
        if isinstance(value, list):
            records.extend(value)
        else:
            records.append(value)


def _parse_non_standard_json(text):
    """Recover a list of records from text that json.loads rejected.

    First wraps ``text`` in ``[...]`` when the brackets are missing, which
    handles comma-separated objects. If that still fails, decodes ``text`` as
    a stream of JSON values, which handles JSON Lines / concatenated objects.

    Raises:
        json.JSONDecodeError: The error from the bracket-wrapping attempt,
            if neither strategy succeeds.
    """
    wrapped = text
    if not wrapped.startswith('['):
        wrapped = '[' + wrapped
    if not wrapped.endswith(']'):
        wrapped = wrapped + ']'

    try:
        return json.loads(wrapped)
    except json.JSONDecodeError as err:
        wrap_error = err

    try:
        return _parse_json_value_stream(text)
    except json.JSONDecodeError:
        raise wrap_error from None


def _load_users_data(data_path):
    """Load user records from a JSON file and return them as a list.

    Accepts a JSON array (returned as-is) or a single JSON object (wrapped in
    a one-element list). If standard parsing fails, it falls back to
    ``_parse_non_standard_json``.

    Raises:
        FileNotFoundError: If ``data_path`` does not exist.
        ValueError: If the top-level value is neither a list nor a dict, or
            if the text cannot be parsed (json.JSONDecodeError subclasses
            ValueError).
    """
    # 'utf-8-sig' reads plain UTF-8 and also strips a leading BOM, which
    # json.loads would otherwise reject. The file is read only once.
    with open(data_path, 'r', encoding='utf-8-sig') as f:
        text = f.read()

    try:
        json_data = json.loads(text)
    except json.JSONDecodeError:
        logging.info("Detected possible non-standard JSON format, attempting correction...")
        users_data = _parse_non_standard_json(text.strip())
        logging.info("JSON format successfully corrected")
        return users_data

    if isinstance(json_data, list):
        return json_data
    if isinstance(json_data, dict):
        # A single user object is treated as a one-record dataset.
        return [json_data]
    raise ValueError("Unsupported JSON format")


def _save_model_artifacts(pipeline, output_dir, save_hf=False):
    """Persist the trained model and its configuration under ``output_dir``.

    Writes ``model.pth``, which holds the state dict plus the vocab maps and
    dimensions read from ``pipeline``, and ``model_config.json``. When
    ``save_hf`` is true, it also writes ``huggingface/pytorch_model.bin`` and
    ``huggingface/config.json``.

    Returns:
        The HuggingFace-format directory (Path) if it was written, else None.
    """
    logging.info("Saving model...")
    output_dir.mkdir(parents=True, exist_ok=True)

    model_path = output_dir / 'model.pth'
    checkpoint = {
        'model_state_dict': pipeline.model.state_dict(),
        'vocab_maps': pipeline.vocab_maps,
        'embedding_dims': pipeline.embedding_dims,
        'output_dim': pipeline.output_dim,
        'max_sequence_length': pipeline.max_sequence_length,
    }
    torch.save(checkpoint, model_path)
    logging.info(f"Model saved to: {model_path}")

    config_info = {
        'model_type': 'UserEmbeddingModel',
        'vocab_sizes': {field: len(vocab) for field, vocab in pipeline.vocab_maps.items()},
        'embedding_dims': pipeline.embedding_dims,
        'output_dim': pipeline.output_dim,
        'max_sequence_length': pipeline.max_sequence_length,
        'padded_fields': list(pipeline.model.padded_fields),
        'fields': pipeline.fields,
    }
    config_path = output_dir / 'model_config.json'
    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump(config_info, f, indent=2)
    logging.info(f"Model configuration saved to: {config_path}")

    if not save_hf:
        return None

    logging.info("Saving in HuggingFace format...")
    hf_dir = output_dir / 'huggingface'
    hf_dir.mkdir(exist_ok=True)
    torch.save(pipeline.model.state_dict(), hf_dir / 'pytorch_model.bin')
    with open(hf_dir / 'config.json', 'w', encoding='utf-8') as f:
        json.dump(config_info, f, indent=2)
    logging.info(f"Model saved in HuggingFace format to: {hf_dir}")
    return hf_dir


def _push_to_huggingface(hf_dir, repo_id, token):
    """Upload every file under ``hf_dir`` to the private Hub repo ``repo_id``.

    This is best-effort: any failure (missing ``huggingface_hub``, auth,
    network) is logged and swallowed, so the local artifacts are unaffected.
    """
    try:
        from huggingface_hub import HfApi

        logging.info(f"Pushing model to HuggingFace: {repo_id}")
        api = HfApi()
        api.create_repo(
            repo_id=repo_id,
            token=token,
            exist_ok=True,
            private=True,
        )

        for file_path in hf_dir.glob("**/*"):
            if file_path.is_file():
                api.upload_file(
                    path_or_fileobj=str(file_path),
                    # path_in_repo is documented as a str; use a POSIX-style
                    # relative path rather than passing a Path object.
                    path_in_repo=file_path.relative_to(hf_dir).as_posix(),
                    repo_id=repo_id,
                    token=token,
                )

        logging.info(f"Model successfully pushed to HuggingFace: {repo_id}")
    except Exception as e:
        logging.error(f"Error pushing to HuggingFace: {str(e)}")


def main():
    """Run the end-to-end training script.

    Stages: load user data, build vocabularies, initialize and train the
    model, save artifacts, and optionally push them to the HuggingFace Hub.

    Configuration comes from the module-level constants (NUM_TRIPLETS,
    NUM_EPOCHS, BATCH_SIZE, ...) and from the environment variables
    SAVE_HF_FORMAT, HF_REPO_ID and HF_TOKEN.

    Returns:
        int: 0 on success, 1 if a required stage failed. The ``__main__``
        guard forwards this as the process exit status, so orchestrators can
        detect failures.
    """
    output_dir = Path(OUTPUT_DIR)

    # Hardware summary.
    cuda_available = torch.cuda.is_available()
    logging.info(f"CUDA available: {cuda_available}")
    if cuda_available:
        logging.info(f"CUDA device: {torch.cuda.get_device_name(0)}")
        logging.info(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    cpu_count = os.cpu_count()
    memory_info = psutil.virtual_memory()
    logging.info(f"CPU cores: {cpu_count}")
    logging.info(f"System memory: {memory_info.total / 1e9:.2f} GB")

    # Effective configuration.
    logging.info("Running with the following configuration:")
    logging.info(f"- Number of triplets: {NUM_TRIPLETS}")
    logging.info(f"- Number of epochs: {NUM_EPOCHS}")
    logging.info(f"- Batch size: {BATCH_SIZE}")
    logging.info(f"- Learning rate: {LEARNING_RATE}")
    logging.info(f"- Output dimension: {OUTPUT_DIM}")
    logging.info(f"- Max sequence length: {MAX_SEQ_LENGTH}")
    logging.info(f"- Save interval: {SAVE_INTERVAL}")
    logging.info(f"- Data path: {DATA_PATH}")
    logging.info(f"- Output directory: {OUTPUT_DIR}")

    logging.info("Loading user data...")
    try:
        users_data = _load_users_data(DATA_PATH)
    except FileNotFoundError:
        logging.error(f"File {DATA_PATH} not found!")
        return 1
    except Exception as e:
        logging.error(f"Unable to load file: {str(e)}")
        return 1
    logging.info(f"Loaded {len(users_data)} records")

    if not users_data:
        logging.error(f"No user records found in {DATA_PATH}; nothing to train on.")
        return 1

    logging.info("Initializing pipeline...")
    pipeline = UserEmbeddingPipeline(
        output_dim=OUTPUT_DIM,
        max_sequence_length=MAX_SEQ_LENGTH,
    )

    logging.info("Building vocabularies...")
    try:
        pipeline.build_vocabularies(users_data)
        vocab_sizes = {field: len(vocab) for field, vocab in pipeline.vocab_maps.items()}
        logging.info(f"Vocabulary sizes: {vocab_sizes}")
    except Exception as e:
        logging.exception(f"Error building vocabularies: {str(e)}")
        return 1

    logging.info("Initializing model...")
    try:
        pipeline.initialize_model()
        logging.info("Model initialized successfully")
    except Exception as e:
        logging.exception(f"Error initializing model: {str(e)}")
        return 1

    logging.info("Starting training...")
    try:
        model_dir = output_dir / "model_checkpoints"
        model_dir.mkdir(exist_ok=True, parents=True)

        model = train_user_embeddings(
            pipeline.model,
            users_data,
            pipeline,
            num_epochs=NUM_EPOCHS,
            batch_size=BATCH_SIZE,
            lr=LEARNING_RATE,
            save_dir=model_dir,
            save_interval=SAVE_INTERVAL,
            num_triplets=NUM_TRIPLETS,
        )
        logging.info("Training completed")
        pipeline.model = model

        save_hf = os.environ.get("SAVE_HF_FORMAT", "false").lower() == "true"
        hf_dir = _save_model_artifacts(pipeline, output_dir, save_hf=save_hf)

        # Pushing requires the HF-format files plus explicit repo credentials.
        hf_repo_id = os.environ.get("HF_REPO_ID")
        hf_token = os.environ.get("HF_TOKEN")
        if hf_dir is not None and hf_repo_id and hf_token:
            _push_to_huggingface(hf_dir, hf_repo_id, hf_token)
    except Exception as e:
        # logging.exception records the traceback through the configured handler.
        logging.exception(f"Error during training or saving: {str(e)}")
        return 1

    logging.info("Process completed successfully!")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())