File size: 5,352 Bytes
fc7b4a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from sklearn.decomposition import IncrementalPCA
from sklearn.preprocessing import StandardScaler
from pathlib import Path

import numpy as np
import pickle
import torch
import os
import joblib

# Module-level slot for a shared SimplePCATrainer instance.
# It is only declared here (set to None); presumably populated lazily by
# training code elsewhere — the assignment site is not visible in this file.
_pca_trainer = None

class SimplePCATrainer:
    """Incrementally fit a PCA model over batches of vectors.

    On the first batch the component count is chosen so that 95% of the
    variance is retained (capped at half the feature count), and a
    StandardScaler is fit on that batch. Every subsequent batch is scaled,
    used to update the IncrementalPCA via ``partial_fit``, and transformed.
    A checkpoint (pickle of the PCA + scaler) is written every
    ``checkpoint_every`` batches.

    Attributes:
        pca: The IncrementalPCA model (None until the first batch arrives).
        scaler: StandardScaler used to normalize each batch.
        fitted: True once the PCA and scaler have been initialized.
        batch_count_pca: Number of batches processed so far.
        checkpoint_dir: Directory where periodic checkpoints are written.
        checkpoint_every: Checkpoint interval, in batches.
    """

    def __init__(self, checkpoint_dir="pca_checkpoints", checkpoint_every=5):
        """Initialize an empty trainer.

        Args:
            checkpoint_dir: Directory for periodic checkpoints
                (default preserves the previous hard-coded path).
            checkpoint_every: Save a checkpoint every N batches (default 5,
                preserving the previous hard-coded interval).
        """
        self.pca = None
        self.scaler = StandardScaler()
        self.fitted = False
        self.batch_count_pca = 0
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_every = checkpoint_every

    def _determine_optimal_components(self, vectors):
        """Determine the number of PCA components to retain 95% variance.

        Args:
            vectors: 2-D array (n_samples, n_features) used for the estimate.

        Returns:
            int: Component count, at least 1 and at most n_features // 2.
        """
        temp_pca = IncrementalPCA()
        temp_pca.fit(vectors)
        cumsum_var = np.cumsum(temp_pca.explained_variance_ratio_)
        # First index where cumulative variance reaches 95%; +1 converts
        # index -> count. If the threshold is never reached, argmax over an
        # all-False array is 0, yielding 1 component.
        n_comp_95 = int(np.argmax(cumsum_var >= 0.95)) + 1
        # Cap at half the feature count, but never below 1: for a
        # single-feature input, shape[1] // 2 would be 0, which is an
        # invalid n_components for IncrementalPCA.
        return max(1, min(n_comp_95, vectors.shape[1] // 2))

    def process_batch(self, vectors):
        """Process a batch of vectors, fitting the PCA model incrementally.

        Args:
            vectors: The input data batch (n_samples, n_features).

        Returns:
            reduced_vectors: The PCA-transformed batch.

        Note: A checkpoint is saved every ``checkpoint_every`` batches.
        """
        if not self.fitted:
            # First batch: pick the component count and fit the scaler.
            # NOTE(review): the scaler is fit only on this first batch —
            # later batches are assumed to share its distribution; confirm
            # this matches how the data is fed in.
            n_components = self._determine_optimal_components(vectors)
            self.pca = IncrementalPCA(n_components=n_components, batch_size=1000)
            self.scaler.fit(vectors)
            self.fitted = True
            print(f"Initialized PCA with {n_components} components")

        # Scale, update the incremental fit, then project this batch.
        vectors_scaled = self.scaler.transform(vectors)
        self.pca.partial_fit(vectors_scaled)
        reduced_vectors = self.pca.transform(vectors_scaled)

        self.batch_count_pca += 1

        if self.batch_count_pca % self.checkpoint_every == 0:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
            checkpoint_path = os.path.join(
                self.checkpoint_dir,
                f"checkpoint_batch_{self.batch_count_pca}.pkl",
            )
            with open(checkpoint_path, 'wb') as f:
                pickle.dump({'pca': self.pca, 'scaler': self.scaler}, f)
            print(f"Saved checkpoint at batch {self.batch_count_pca}")

        print(f"Processed batch {self.batch_count_pca}, shape: {vectors.shape} -> {reduced_vectors.shape}")
        return reduced_vectors

    def save_final(self, model_path):
        """Save the final PCA model and scaler to the specified path.

        Args:
            model_path: The file path to save the PCA model.

        Returns:
            None

        Note: Change the model path as needed in the data_config.yml file.
        """
        # dirname() is "" for a bare filename, and os.makedirs("") raises
        # FileNotFoundError — only create the parent when there is one.
        parent = os.path.dirname(model_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(model_path, 'wb') as f:
            pickle.dump({'pca': self.pca, 'scaler': self.scaler}, f)
        print(f"Final model saved to {model_path}. Total variance explained: {np.sum(self.pca.explained_variance_ratio_):.4f}")

## For Single Input
def load_pca_model(vectors, model_path="models/fusion/pca.pkl"):
    """
    Load a pre-trained PCA model and transform the input vectors.

    Args:
        vectors: The input data to transform (n_samples, n_features).
        model_path: The file path of the pre-trained PCA model.

    Returns:
        output: The PCA-transformed data.

    Note: Change the model path as needed in the data_config.yml file (or set the path file as shown above). Can be used for the main program.
    """
    model_path = Path(model_path)
    # The model file is written by pickle.dump (see SimplePCATrainer.save_final),
    # so load it with pickle for consistency.
    with open(model_path, 'rb') as f:
        model = pickle.load(f)
    if isinstance(model, dict):
        # Training saves {'pca': ..., 'scaler': ...}: apply the same scaling
        # used during training before projecting, instead of calling
        # .transform on the raw dict (which would raise AttributeError).
        return model['pca'].transform(model['scaler'].transform(vectors))
    # Backward compatibility: a bare estimator exposing transform().
    return model.transform(vectors)

def l2vec_single_train(l2v, lyrics):
    """
    Encode a single lyric string using the provided LLM2Vec model.

    Args:
        l2v: The LLM2Vec model for encoding lyrics.
        lyrics: A single lyric string to encode.

    Returns:
        vectors: The vector representation of the lyrics as a NumPy array
        (leading axis is the batch of size 1).
    """
    # Run under no_grad, matching l2vec_train: without it, encode builds an
    # autograd graph that is immediately discarded by .detach(), wasting memory.
    with torch.no_grad():
        encoded = l2v.encode([lyrics])
    vectors = encoded.detach().cpu().numpy()
    return vectors

# For Batch Processing
def l2vec_train(l2v, lyrics_list):
    """
    Encode a batch of lyric strings with the given LLM2Vec model.

    Args:
        l2v: The LLM2Vec model used for encoding.
        lyrics_list: List of lyric strings to encode.

    Returns:
        vectors: Encoded vector representations of the lyrics.

    Note: This function only encodes the lyrics and does not apply PCA reduction. The PCA reduction can be applied separately in the train.py module.
    """
    # Inference only — skip autograd bookkeeping while encoding.
    with torch.no_grad():
        return l2v.encode(lyrics_list)