import io
import json
import re

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from datasets import load_dataset
from PIL import Image
from transformers import AutoTokenizer
tokenizers = [
    "google/gemma-7b",
    "meta-llama/Llama-2-7b",
    "mistralai/Mistral-7B-v0.1",
    "facebook/opt-2.7b",
    "microsoft/phi-2",
    "THUDM/chatglm3-6b",
    "Qwen/Qwen1.5-7B-Chat",
    "bigscience/bloom-560m",
    "ise-uiuc/Magicoder-S-DS-6.7B",
    "google/flan-t5-base",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
]
def plot_histogram(data):
    # Use a fresh figure for each call so repeated runs don't stack histograms
    # on top of each other in the long-running Gradio process.
    fig, ax = plt.subplots()
    ax.hist(data)
    ax.set_title("Histogram of number of tokens per dataset item")
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def count(model_id, dataset_id, config, split, column, add_special_tokens=True):
    tokencounter = []
    wordcounter = []
    charcounter = []
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if config == "":
        config = None  # an empty Config textbox means "no config"
    dataset = load_dataset(dataset_id, config, split=split, trust_remote_code=True)
    pattern = r"[a-zA-Z]+"
    for item in dataset:
        tokens = tokenizer(item[column], add_special_tokens=add_special_tokens)["input_ids"]
        tokencounter.append(len(tokens))
        charcounter.append(len(item[column]))
        # simple regex word count; not 100% accurate but good enough
        words = re.findall(pattern, item[column])
        wordcounter.append(len(words))
    # Build one summary table with a row each for tokens, words and chars.
    df = pd.DataFrame(tokencounter).describe().T
    df.insert(0, "type", "tokens")
    dfc = pd.DataFrame(charcounter).describe().T
    dfc.insert(0, "type", "chars")
    dfw = pd.DataFrame(wordcounter).describe().T
    dfw.insert(0, "type", "words")
    df.loc[-1] = dfw.values[0]
    df.index = df.index + 1  # shift so the next loc[-1] adds a row instead of overwriting "words"
    df.loc[-1] = dfc.values[0]
    df = df.round(1)
    df.drop("count", axis=1, inplace=True)
    return plot_histogram(tokencounter), df
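
# Illustrative only (not part of the original app): count() can also be called
# directly, without the Gradio UI; the dataset/column names below are just an
# example and assume network access to the Hugging Face Hub.
#
#   hist_image, stats_df = count("mistralai/Mistral-7B-v0.1", "imdb", "", "test", "text")
#   print(stats_df)
#   hist_image.save("token_histogram.png")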
demo = gr.Interface(
    fn=count,
    title="Dataset token counts and distribution",
    inputs=[
        gr.Dropdown(label="Tokenizer", choices=tokenizers, allow_custom_value=True),
        gr.Textbox(label="Dataset"),
        gr.Textbox(label="Config"),
        gr.Textbox(label="Split"),
        gr.Textbox(label="Column"),
        gr.Checkbox(label="Add special tokens", value=True),
    ],
    outputs=[
        gr.Image(),
        gr.Dataframe(label="Token, word and character counts per dataset item"),
    ],
    examples=[
        ["mistralai/Mistral-7B-v0.1", "gsarti/flores_101", "eng", "dev", "sentence"],
        ["mistralai/Mistral-7B-v0.1", "Muennighoff/flores200", "eng_Latn", "dev", "sentence"],
        ["mistralai/Mistral-7B-v0.1", "wikitext", "wikitext-2-v1", "validation", "text"],
        ["mistralai/Mistral-7B-v0.1", "hails/mmlu_no_train", "elementary_mathematics", "test", "question"],
        ["mistralai/Mistral-7B-v0.1", "imdb", "", "test", "text"],
        ["mistralai/Mistral-7B-v0.1", "gsm8k", "main", "test", "question"],
        ["mistralai/Mistral-7B-v0.1", "locuslab/TOFU", "world_facts", "train", "question"],
    ],
    cache_examples=False,
)

demo.launch()
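
# To run this Space locally (a sketch, assuming the file is saved as app.py
# and a recent Python environment):
#
#   pip install gradio transformers datasets pandas matplotlib pillow
#   python app.py
#
# Gated tokenizers such as meta-llama/Llama-2-7b and google/gemma-7b also
# require an authenticated Hugging Face login (huggingface-cli login).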