File size: 571 Bytes
45fe8b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from sklearn.mixture import GaussianMixture
import pickle


def build_gmm_clusters(
    embeddings,
    n_clusters=12,
    *,
    covariance_type="full",
    random_state=42,
):
    """Fit a Gaussian mixture model and return soft cluster memberships.

    Args:
        embeddings: Training data, array-like of shape (n_samples, n_features),
            as accepted by ``GaussianMixture.fit``.
        n_clusters: Number of mixture components.
        covariance_type: Covariance parameterization passed straight to
            ``GaussianMixture`` ("full", "tied", "diag" or "spherical").
            Defaults to "full", matching the previous hard-coded value.
        random_state: Seed passed to ``GaussianMixture`` so repeated fits are
            reproducible. Defaults to 42, matching the previous hard-coded value.

    Returns:
        A tuple ``(gmm, cluster_probs)``. ``gmm`` is the fitted
        ``GaussianMixture``. ``cluster_probs`` is the output of
        ``gmm.predict_proba(embeddings)``, with shape (n_samples, n_clusters),
        giving each sample's posterior probability for every component.
    """
    gmm = GaussianMixture(
        n_components=n_clusters,
        covariance_type=covariance_type,
        random_state=random_state,
    )

    gmm.fit(embeddings)

    # Soft assignments on the same data the model was fitted on.
    cluster_probs = gmm.predict_proba(embeddings)

    return gmm, cluster_probs


def save_gmm_model(gmm, path="models/gmm_model.pkl"):
    """Serialize *gmm* to *path* using :mod:`pickle`.

    Any missing parent directories of *path* are created first. This lets the
    default ``models/`` location work even when that directory does not exist
    yet.

    Args:
        gmm: Object to persist, e.g. the fitted GaussianMixture returned by
            ``build_gmm_clusters``. Any picklable object is accepted.
        path: Destination file path (str or path-like). An existing file at
            this path is overwritten.
    """
    from pathlib import Path

    target = Path(path)
    # Without this, open(..., "wb") raises FileNotFoundError when the
    # directory is absent. For a bare filename the parent is ".", so this is a no-op.
    target.parent.mkdir(parents=True, exist_ok=True)

    with open(target, "wb") as f:
        pickle.dump(gmm, f)


def load_gmm_model(path="models/gmm_model.pkl"):
    """Deserialize and return the object pickled at *path*.

    Args:
        path: File previously written with ``pickle.dump`` (for example by
            ``save_gmm_model``).

    Returns:
        Whatever object was stored in the file.

    Warning:
        ``pickle.load`` can execute arbitrary code embedded in the file, so only
        load files from a trusted source.
    """
    with open(path, "rb") as handle:
        model = pickle.load(handle)
    return model