bach-or-bot / src /llm2vectrain /llm2vec_trainer.py
Acelle Krislette Rosales
Initial commit: Added application code
fc7b4a9
raw
history blame
5.35 kB
from sklearn.decomposition import IncrementalPCA
from sklearn.preprocessing import StandardScaler
from pathlib import Path
import numpy as np
import pickle
import torch
import os
import joblib
# Module-level placeholder for a shared trainer instance. It starts as None and
# is not assigned anywhere in the visible code (presumably populated by the
# training entry point — TODO confirm).
_pca_trainer = None
class SimplePCATrainer:
    """
    Incrementally fit a StandardScaler + IncrementalPCA pair over batches.

    The first batch given to ``process_batch`` picks the number of components
    and fits the scaler. Every batch (the first one included) is then
    standardized, passed to ``IncrementalPCA.partial_fit`` and transformed.
    A checkpoint is pickled every ``CHECKPOINT_EVERY`` batches, and
    ``save_final`` writes the finished model.

    Attributes:
        pca: The IncrementalPCA model (None until the first batch arrives).
        scaler: StandardScaler for normalizing data; fitted on the first batch only.
        fitted: Boolean indicating if the model has been initialized.
        batch_count_pca: Counter for the number of batches processed.
    Methods:
        process_batch(vectors): Fits the PCA model on one batch and returns the reduced batch.
        save_final(model_path): Saves the final PCA model and scaler to the specified path.
    """

    # Checkpoint location and cadence; override on a subclass or instance.
    CHECKPOINT_DIR = "pca_checkpoints"
    CHECKPOINT_EVERY = 5
    # Cumulative explained-variance ratio the component count should reach.
    VARIANCE_TARGET = 0.95

    def __init__(self):
        self.pca = None
        self.scaler = StandardScaler()
        self.fitted = False
        self.batch_count_pca = 0

    def _determine_optimal_components(self, vectors):
        """
        Determine how many PCA components are needed to retain 95% variance.

        Args:
            vectors: The input data to analyze (2-D, samples x features).
        Returns:
            n_components: The smallest component count whose cumulative
                explained-variance ratio reaches ``VARIANCE_TARGET``, capped at
                half the feature count and never less than 1.
        """
        # NOTE(review): the probe is fitted on the raw vectors, while the real
        # PCA in process_batch is fitted on standardized vectors, so the 95%
        # figure is only approximate for the final model. Left as-is because
        # changing it would change the output dimensionality.
        probe = IncrementalPCA()
        probe.fit(vectors)
        cumulative = np.cumsum(probe.explained_variance_ratio_)
        reached = cumulative >= self.VARIANCE_TARGET
        if reached.any():
            # argmax on a boolean mask gives the index of the first True.
            n_for_target = int(np.argmax(reached)) + 1
        else:
            # Target never reached (e.g. rounding): keep every component
            # instead of silently falling back to a single one.
            n_for_target = len(cumulative)
        half_width = vectors.shape[1] // 2
        # Clamp to 1 so narrow inputs (one feature) cannot yield the invalid
        # value n_components=0.
        return max(1, min(n_for_target, half_width))

    def process_batch(self, vectors):
        """
        Process a batch of vectors, fitting the PCA model incrementally.

        Args:
            vectors: The input data batch to process (2-D, samples x features).
        Returns:
            reduced_vectors: The PCA-transformed data for this batch.
        Note: This method saves a checkpoint every ``CHECKPOINT_EVERY`` batches
            into ``CHECKPOINT_DIR``.
        """
        if not self.fitted:
            # First batch - choose the dimensionality and fit the scaler. The
            # scaler is deliberately not updated afterwards, so all batches
            # share the same standardization.
            n_components = self._determine_optimal_components(vectors)
            self.pca = IncrementalPCA(n_components=n_components, batch_size=1000)
            self.scaler.fit(vectors)
            self.fitted = True
            print(f"Initialized PCA with {n_components} components")

        # Standardize, update the PCA with this batch, then project the batch.
        vectors_scaled = self.scaler.transform(vectors)
        self.pca.partial_fit(vectors_scaled)
        reduced_vectors = self.pca.transform(vectors_scaled)
        self.batch_count_pca += 1

        # Periodic checkpoint so a long run can be inspected or recovered.
        if self.batch_count_pca % self.CHECKPOINT_EVERY == 0:
            os.makedirs(self.CHECKPOINT_DIR, exist_ok=True)
            checkpoint_path = os.path.join(
                self.CHECKPOINT_DIR,
                f"checkpoint_batch_{self.batch_count_pca}.pkl",
            )
            with open(checkpoint_path, 'wb') as f:
                pickle.dump({'pca': self.pca, 'scaler': self.scaler}, f)
            print(f"Saved checkpoint at batch {self.batch_count_pca}")

        print(f"Processed batch {self.batch_count_pca}, shape: {vectors.shape} -> {reduced_vectors.shape}")
        return reduced_vectors

    def save_final(self, model_path):
        """
        Save the final PCA model and scaler to the specified path.

        The file is a pickle of ``{'pca': ..., 'scaler': ...}``.

        Args:
            model_path: The file path to save the PCA model.
        Returns:
            None
        Raises:
            RuntimeError: If no batch has been processed yet, so there is no
                model to save.
        Note: Change the model path as needed in the data_config.yml file.
        """
        if self.pca is None:
            # Fail before touching the filesystem instead of writing a file
            # that contains no model.
            raise RuntimeError(
                "No PCA model to save: call process_batch() at least once "
                "before save_final()."
            )

        # A bare filename has an empty dirname, and os.makedirs("") raises.
        parent_dir = os.path.dirname(model_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        with open(model_path, 'wb') as f:
            pickle.dump({'pca': self.pca, 'scaler': self.scaler}, f)
        print(f"Final model saved to {model_path}. Total variance explained: {np.sum(self.pca.explained_variance_ratio_):.4f}")
## For Single Input
def load_pca_model(vectors, model_path="models/fusion/pca.pkl"):
    """
    Load a pre-trained PCA model and transform the input vectors.

    Supports both on-disk layouts: the ``{'pca': ..., 'scaler': ...}`` dict
    written by ``SimplePCATrainer.save_final`` (the scaler, when present, is
    applied before the PCA, as it was during training) and a file holding a
    bare estimator that exposes ``transform``.

    Args:
        vectors: The input data to transform.
        model_path: The file path of the pre-trained PCA model.
    Returns:
        output: The PCA-transformed data.
    Raises:
        KeyError: If the file holds a dict without a ``'pca'`` entry.
    Note: Change the model path as needed in the data_config.yml file (or set the path file as shown above). Can be used for the main program.
    """
    model_path = Path(model_path)
    # SECURITY: joblib.load unpickles arbitrary objects; only point this at
    # model files from a trusted source.
    loaded = joblib.load(model_path)

    if isinstance(loaded, dict):
        # Layout produced by save_final: calling .transform on the dict itself
        # would fail, so unpack it and replay the training-time scaling.
        pca = loaded["pca"]
        scaler = loaded.get("scaler")
        if scaler is not None:
            vectors = scaler.transform(vectors)
    else:
        # Backward compatibility: the file stores the estimator directly.
        pca = loaded

    return pca.transform(vectors)
def l2vec_single_train(l2v, lyrics):
    """
    Turn one lyric string into its LLM2Vec embedding as a NumPy array.

    Args:
        l2v: The LLM2Vec model for encoding lyrics; its ``encode`` is called
            with a one-element list.
        lyrics: A single lyric string to encode.
    Returns:
        vectors: The encoder output for the one-item batch, detached from
            autograd, moved to CPU and converted to a NumPy array.
    """
    # encode works on a batch, so wrap the single string in a list.
    embedding = l2v.encode([lyrics])
    # Drop the autograd graph and move to host memory before the conversion.
    host_tensor = embedding.detach().cpu()
    return host_tensor.numpy()
# For Batch Processing
def l2vec_train(l2v, lyrics_list):
    """
    Encode a batch of lyric strings with the given LLM2Vec model.

    Args:
        l2v: The LLM2Vec model for encoding lyrics.
        lyrics_list: A list of lyric strings to encode.
    Returns:
        vectors: Whatever ``l2v.encode`` returns for the batch, computed with
            gradient tracking disabled.
    Note: This function only encodes the lyrics and does not apply PCA reduction. The PCA reduction can be applied separately in the train.py module.
    """
    # Encoding is inference only, so run it without building an autograd graph.
    inference_mode = torch.no_grad()
    with inference_mode:
        return l2v.encode(lyrics_list)  # lyrics_list: list of strings