| | import hdbscan |
| | import numpy as np |
| |
|
| | |
def _is_cuml_hdbscan(model) -> bool:
    """Return True when the model's type string mentions both "cuml" and "hdbscan"."""
    # The check is done on the type's string form so that cuML never has to
    # be imported just to find out that the model is not a cuML model.
    type_name = str(type(model)).lower()
    return "cuml" in type_name and "hdbscan" in type_name


def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None):
    """ Function used to select the HDBSCAN-like model for generating
    predictions and probabilities.

    Dispatches `func` to the `hdbscan` package when `model` is an
    `hdbscan.HDBSCAN` instance, and to `cuml.cluster.hdbscan` when the
    model's type string contains both "cuml" and "hdbscan". cuML is
    imported lazily, only on the cuML paths.

    Arguments:
        model: The cluster model.
        func: The function to use. Options:
                - "approximate_predict"
                - "all_points_membership_vectors"
                - "membership_vector"
        embeddings: Input embeddings for "approximate_predict"
                    and "membership_vector"

    Returns:
        - "approximate_predict": a `(predictions, probabilities)` tuple.
          For models that are neither HDBSCAN nor cuML HDBSCAN,
          `model.predict(embeddings)` is used and `probabilities` is None.
        - "all_points_membership_vectors": the result of the backend's
          `all_points_membership_vectors(model)`, or None for other models.
        - "membership_vector": the result of the backend's
          `membership_vector(model, embeddings)`, or None for other models.
        - any other value of `func`: None.
    """
    # Approximate predict
    if func == "approximate_predict":
        if isinstance(model, hdbscan.HDBSCAN):
            predictions, probabilities = hdbscan.approximate_predict(model, embeddings)
            return predictions, probabilities

        if _is_cuml_hdbscan(model):
            from cuml.cluster import hdbscan as cuml_hdbscan
            predictions, probabilities = cuml_hdbscan.approximate_predict(model, embeddings)
            return predictions, probabilities

        # Any other cluster model: fall back to its own `predict`, which
        # yields labels only.
        predictions = model.predict(embeddings)
        return predictions, None

    # All points membership
    if func == "all_points_membership_vectors":
        if isinstance(model, hdbscan.HDBSCAN):
            return hdbscan.all_points_membership_vectors(model)

        if _is_cuml_hdbscan(model):
            from cuml.cluster import hdbscan as cuml_hdbscan
            return cuml_hdbscan.all_points_membership_vectors(model)

        return None

    # Probabilities
    if func == "membership_vector":
        if isinstance(model, hdbscan.HDBSCAN):
            probabilities = hdbscan.membership_vector(model, embeddings)
            return probabilities

        if _is_cuml_hdbscan(model):
            from cuml.cluster import hdbscan as cuml_hdbscan
            # Use cuML's membership_vector here: approximate_predict yields a
            # (predictions, probabilities) pair, not the membership vectors
            # this branch is expected to return.
            # NOTE(review): requires a cuML release that ships
            # `membership_vector` - confirm the minimum supported version.
            probabilities = cuml_hdbscan.membership_vector(model, embeddings)
            return probabilities

        return None

    # Unrecognised `func`: nothing to delegate to.
    return None
| |
|
| |
|
def is_supported_hdbscan(model):
    """ Tell whether `model` is one of the supported HDBSCAN-like models.

    A model is supported when it is an `hdbscan.HDBSCAN` instance, or when
    the lower-cased string form of its type contains both "cuml" and
    "hdbscan".

    Arguments:
        model: The cluster model to inspect.

    Returns:
        True if the model is supported, False otherwise.
    """
    if isinstance(model, hdbscan.HDBSCAN):
        return True

    # Matching on the type's string form avoids importing cuML here.
    type_name = str(type(model)).lower()
    return all(marker in type_name for marker in ("cuml", "hdbscan"))
| |
|