File size: 4,005 Bytes
86e673e
 
 
 
7b62017
86e673e
 
 
 
7b62017
 
86e673e
 
7b62017
 
 
 
 
86e673e
 
 
 
 
 
 
 
 
 
 
7b62017
 
 
 
 
 
 
 
 
 
86e673e
7b62017
 
 
 
 
 
 
86e673e
 
 
 
 
7b62017
 
 
 
 
 
 
86e673e
 
 
 
 
 
 
7b62017
 
 
86e673e
 
 
 
 
7b62017
86e673e
 
 
 
7b62017
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86e673e
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import logging
from functools import partial
from typing import Optional

import pandas as pd
import typer
from bokeh.plotting import output_file as bokeh_output_file
from bokeh.plotting import save
from embedding_lenses.data import uploaded_file_to_dataframe
from embedding_lenses.dimensionality_reduction import (get_tsne_embeddings,
                                                       get_umap_embeddings)
from embedding_lenses.embedding import load_model

from perplexity_lenses.data import (documents_df_to_sentences_df,
                                    hub_dataset_to_dataframe)
from perplexity_lenses.engine import (DIMENSIONALITY_REDUCTION_ALGORITHMS,
                                      DOCUMENT_TYPES, EMBEDDING_MODELS,
                                      LANGUAGES, SEED, generate_plot)
from perplexity_lenses.perplexity import KenlmModel

# Configure root logging when the module is imported so the INFO-level progress
# messages emitted by `main` are visible on the console.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Typer application object; commands are registered on it via `@app.command()`.
app = typer.Typer()


def _load_local_dataset(
    path: str,
    text_column: str,
    doc_type: str,
    sample: int,
    kenlm_model: KenlmModel,
) -> pd.DataFrame:
    """Load a local csv/tsv file, limit it to ``sample`` examples and score perplexity.

    Args:
        path: Path to a ``.csv`` (comma-separated) or ``.tsv`` (tab-separated) file.
        text_column: Name of the column holding the text to score.
        doc_type: ``"sentence"`` (case-insensitive) to split documents into
            sentences via ``documents_df_to_sentences_df``; anything else keeps
            one row per document.
        sample: Maximum number of examples to use.
        kenlm_model: Model whose ``get_perplexity`` is applied to every text.

    Returns:
        The loaded frame with an added ``"perplexity"`` column.
    """
    df = pd.read_csv(path, sep="\t" if path.endswith(".tsv") else ",")
    if doc_type.lower() == "sentence":
        df = documents_df_to_sentences_df(df, text_column, sample, seed=SEED)
    elif len(df) > sample:
        # Document level: honour `sample` here as well. Without this the whole
        # file was scored and plotted. Sampling happens before perplexity
        # scoring to avoid scoring rows that would be discarded.
        df = df.sample(n=sample, random_state=SEED)
    df["perplexity"] = df[text_column].map(kenlm_model.get_perplexity)
    return df


@app.command()
def main(
    dataset: str = typer.Option(
        "mc4", help="The name of the hub dataset or local csv/tsv file."
    ),
    dataset_config: Optional[str] = typer.Option(
        "es",
        help="The configuration of the hub dataset, if any. Does not apply to local csv/tsv files.",
    ),
    dataset_split: Optional[str] = typer.Option(
        "train", help="The dataset split. Does not apply to local csv/tsv files."
    ),
    text_column: str = typer.Option("text", help="The text field name."),
    language: str = typer.Option(
        "es", help=f"The language of the text. Options: {LANGUAGES}"
    ),
    doc_type: str = typer.Option(
        "sentence",
        help=f"Whether to embed at the sentence or document level. Options: {DOCUMENT_TYPES}.",
    ),
    sample: int = typer.Option(1000, help="Maximum number of examples to use."),
    dimensionality_reduction: str = typer.Option(
        DIMENSIONALITY_REDUCTION_ALGORITHMS[0],
        help=f"Whether to use UMAP or t-SNE for dimensionality reduction. Options: {DIMENSIONALITY_REDUCTION_ALGORITHMS}.",
    ),
    model_name: str = typer.Option(
        EMBEDDING_MODELS[0],
        help=f"The sentence embedding model to use. Options: {EMBEDDING_MODELS}",
    ),
    output_file: str = typer.Option(
        "perplexity.html", help="The name of the output visualization HTML file."
    ),
):
    """
    Perplexity Lenses: Visualize text embeddings in 2D using colors to represent perplexity values.
    """
    # NOTE: the docstring above is shown by Typer as the command's help text,
    # so keep it user-facing.
    logger.info("Loading embedding model...")
    model = load_model(model_name)
    # "umap" (case-insensitive) selects UMAP; every other value falls back to t-SNE.
    reducer = (
        get_umap_embeddings
        if dimensionality_reduction.lower() == "umap"
        else get_tsne_embeddings
    )
    dimensionality_reduction_function = partial(reducer, random_state=SEED)
    logger.info("Loading KenLM model...")
    kenlm_model = KenlmModel.from_pretrained(language)
    logger.info("Loading dataset...")
    if dataset.endswith((".csv", ".tsv")):
        df = _load_local_dataset(dataset, text_column, doc_type, sample, kenlm_model)
    else:
        df = hub_dataset_to_dataframe(
            dataset,
            dataset_config,
            dataset_split,
            sample,
            text_column,
            kenlm_model,
            seed=SEED,
            doc_type=doc_type,
        )
    if df.empty:
        # Fail early with a clear message rather than inside dimensionality reduction.
        logger.error("No examples loaded from dataset %s; nothing to plot.", dataset)
        raise typer.Exit(code=1)
    # Round perplexity to whole numbers before plotting.
    df["perplexity"] = df["perplexity"].round().astype(int)
    logger.info(
        "Perplexity range: %s - %s", df["perplexity"].min(), df["perplexity"].max()
    )
    plot = generate_plot(
        df,
        text_column,
        "perplexity",
        None,
        dimensionality_reduction_function,
        model,
        seed=SEED,
    )
    logger.info("Saving plot")
    bokeh_output_file(output_file)
    save(plot)
    logger.info("Done")


if __name__ == "__main__":
    # Run the Typer CLI when this file is executed directly as a script.
    app()