Spaces:
Runtime error
Runtime error
| import requests | |
| from collections import Counter | |
| from requests.adapters import HTTPAdapter, Retry | |
| import gradio as gr | |
| import pandas as pd | |
| import polars as pl | |
| import spaces | |
| from gradio_huggingfacehub_search import HuggingfaceHubSearch | |
| from huggingface_hub import PyTorchModelHubMixin | |
| import torch | |
| from torch import nn | |
| from transformers import AutoModel, AutoTokenizer, AutoConfig | |
# Shared HTTP session for the datasets-server API call made in this file.
# Responses with status 502/503/504 are retried up to 5 times with backoff.
session = requests.Session()
retries = Retry(total=5, backoff_factor=1, status_forcelist=[502, 503, 504])
# Adapters are selected by URL prefix. The request issued in this file targets an
# https:// URL, so the retrying adapter has to be mounted for that scheme as well;
# an "http://"-only mount would leave the retry policy unused.
_retry_adapter = HTTPAdapter(max_retries=retries)
session.mount("http://", _retry_adapter)
session.mount("https://", _retry_adapter)
class QualityModel(nn.Module, PyTorchModelHubMixin):
    """Quality classifier: a pretrained encoder followed by dropout and a linear head.

    ``config`` is read with the keys ``"base_model"`` (encoder to load),
    ``"fc_dropout"`` (dropout probability) and ``"id2label"`` (one head output
    per entry).
    """

    def __init__(self, config):
        super().__init__()
        self.model = AutoModel.from_pretrained(config["base_model"])
        self.dropout = nn.Dropout(config["fc_dropout"])
        num_labels = len(config["id2label"])
        self.fc = nn.Linear(self.model.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return class probabilities taken from position 0 of every sequence."""
        encoded = self.model(input_ids=input_ids, attention_mask=attention_mask)
        token_logits = self.fc(self.dropout(encoded.last_hidden_state))
        # Keep only index 0 along the second axis, then normalise over classes.
        first_position_logits = token_logits[:, 0, :]
        return torch.softmax(first_position_logits, dim=1)
# Hub repository that supplies the classifier's config, tokenizer and weights.
_MODEL_REPO = "nvidia/quality-classifier-deberta"

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

config = AutoConfig.from_pretrained(_MODEL_REPO)
tokenizer = AutoTokenizer.from_pretrained(_MODEL_REPO)
model = QualityModel.from_pretrained(_MODEL_REPO).to(device)
# Switch to evaluation mode (this disables dropout in the classifier head).
model.eval()
def predict(texts: list[str]) -> list[str]:
    """Classify a batch of texts and return one quality label per text.

    Args:
        texts: Raw strings to score. They are tokenized together, padded to the
            longest item of the batch and truncated by the tokenizer.

    Returns:
        For every input, in input order, the ``config.id2label`` entry of the
        class with the highest probability. An empty input yields an empty list.
    """
    if not texts:
        # Nothing to tokenize, so skip the model call entirely.
        return []
    inputs = tokenizer(
        texts, return_tensors="pt", padding="longest", truncation=True
    ).to(device)
    # Inference only: without no_grad, autograd records the graph and retains
    # intermediate activations during the forward pass, wasting memory.
    with torch.no_grad():
        outputs = model(inputs["input_ids"], inputs["attention_mask"])
    predicted_classes = torch.argmax(outputs, dim=1)
    # tolist() copies the indices to the host and converts them to Python ints.
    return [config.id2label[class_idx] for class_idx in predicted_classes.tolist()]
def plot_and_df(texts, preds):
    """Summarise predictions as a count plot plus sample texts per label.

    Returns a 4-tuple: a ``gr.BarPlot`` of the Low/Medium/High counts, followed
    by single-column ("text") frames holding at most the first 20 texts
    predicted as Low, Medium and High, in that order.
    """
    texts_df = pd.DataFrame({"quality": preds, "text": texts})

    labels = ["Low", "Medium", "High"]
    tally = Counter(preds)
    counts_df = pd.DataFrame(
        {
            "quality": labels,
            "count": [tally.get(label, 0) for label in labels],
        }
    )

    def first_texts(label):
        # Rows predicted as `label`, "text" column only, first 20 at most.
        return texts_df.loc[texts_df["quality"] == label, ["text"]].head(20)

    bar_plot = gr.BarPlot(counts_df, x="quality", y="count")
    return bar_plot, first_texts("Low"), first_texts("Medium"), first_texts("High")
def run_quality_check(dataset, column, batch_size, num_examples):
    """Stream quality predictions for the first rows of a Hub dataset column.

    Generator meant to be used as a Gradio event handler. Every yielded value
    is a 5-tuple ``(status, plot, low_df, medium_df, high_df)``: ``status`` is
    either an error string or a ``{label: fraction}`` dict, and the remaining
    four items come from ``plot_and_df`` (empty placeholders on errors).

    Args:
        dataset: Hub dataset id.
        column: Name of the column holding the texts.
        batch_size: Number of texts per ``predict`` call; values below 1
            (or a missing value) are treated as 1.
        num_examples: How many leading non-null texts to check; capped at the
            number of texts available, and a missing value is treated as 0.
    """

    def failure(message):
        # NOTE(review): the "β " prefix looks like a mis-decoded emoji -
        # confirm the intended character and restore it.
        return f"β {message}", gr.BarPlot(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

    if not dataset or not column:
        # The column textbox has no default value, so an empty name is easy to submit.
        yield failure("Please provide both a dataset id and a text column name.")
        return

    try:
        info_resp = session.get(
            f"https://datasets-server.huggingface.co/info?dataset={dataset}", timeout=3
        ).json()
    except (requests.RequestException, ValueError) as error:
        # Network failures, timeouts and non-JSON bodies are reported like other errors.
        yield failure(error)
        return
    if "error" in info_resp:
        yield failure(info_resp["error"])
        return

    # Prefer the "default" config and the "train" split, else take the first listed.
    # The local is named config_name so it does not shadow the module-level model config.
    dataset_info = info_resp["dataset_info"]
    config_name = "default" if "default" in dataset_info else next(iter(dataset_info))
    splits = dataset_info[config_name]["splits"]
    split = "train" if "train" in splits else next(iter(splits))

    # Only the first parquet shard (0000.parquet) is read; on a ComputeError the
    # "partial-<split>" directory is tried before giving up.
    parquet_root = f"hf://datasets/{dataset}@~parquet/{config_name}"
    try:
        data = pl.read_parquet(f"{parquet_root}/{split}/0000.parquet", columns=[column])
    except pl.exceptions.ComputeError:
        try:
            data = pl.read_parquet(
                f"{parquet_root}/partial-{split}/0000.parquet", columns=[column]
            )
        except Exception as error:
            yield failure(error)
            return

    # UI widgets may deliver floats or None; range() and slicing need ints, and a
    # zero step would make range() raise.
    batch_size = max(int(batch_size or 0), 1)
    num_examples = max(int(num_examples or 0), 0)

    # Null rows cannot be tokenized, so drop them before taking the first rows.
    texts = data[column].drop_nulls().to_list()
    num_examples = min(len(texts), num_examples)
    # Truncate up front so the last batch never runs past the requested count.
    texts = texts[:num_examples]

    predictions, texts_processed = [], []
    for start in range(0, num_examples, batch_size):
        batch_texts = texts[start:start + batch_size]
        predictions.extend(predict(batch_texts))
        texts_processed.extend(batch_texts)
        # Based on what was actually processed, so the fraction never exceeds 1.
        progress = len(texts_processed) / num_examples
        yield {"check in progress...": progress}, *plot_and_df(texts_processed, predictions)
    yield {"finished": 1.}, *plot_and_df(texts_processed, predictions)
# Gradio UI: pick a Hub dataset and a text column, then stream the quality check.
with gr.Blocks() as demo:
    # NOTE(review): "π«" in the heading looks like a mis-decoded emoji - confirm
    # the intended character and restore it.
    # The text is kept flush-left so no line can be read as an indented Markdown
    # code block.
    gr.Markdown(
        """
# π« Dataset Quality Checker π«
Use [nvidia/quality-classifier-deberta](https://huggingface.co/nvidia/quality-classifier-deberta) on any text dataset on the Hub.
"""
    )
    dataset_name = HuggingfaceHubSearch(
        label="Hub Dataset ID",
        placeholder="Search for dataset id on Huggingface",
        search_type="dataset",
    )
    # TODO: let the user choose the dataset config instead of relying on auto-detection.
    with gr.Accordion("Dataset preview", open=False):
        # NOTE(review): nothing visible in this file registers `embed` as a callback,
        # so this accordion renders empty - confirm whether the wiring was lost.
        def embed(name):
            """Return an HTML component embedding the Hub dataset viewer for `name`."""
            # The viewer URL always targets the "default" config and "train" split.
            html_code = f"""
            <iframe
            src="https://huggingface.co/datasets/{name}/embed/viewer/default/train"
            frameborder="0"
            width="100%"
            height="700px"
            ></iframe>
            """
            return gr.HTML(value=html_code)

    text_column = gr.Textbox(
        placeholder="text",
        label="Text column name to check (data must be non-nested, raw texts!)",
    )
    # The minimum is one step (8) rather than 0: a batch size of 0 cannot be used
    # as the step of the batching loop in run_quality_check.
    batch_size = gr.Slider(
        8,
        128,
        32,
        step=8,
        label="Inference batch size (set this to smaller value if this space crashes.)",
    )
    # precision=0 makes the component deliver an integer to the handler.
    num_examples = gr.Number(500, precision=0, label="Number of first examples to check")
    gr_check_btn = gr.Button("Check Dataset")
    progress_bar = gr.Label(show_label=False)
    plot = gr.BarPlot()

    with gr.Accordion("Explore some individual examples for each class", open=False):
        gr.Markdown("### Low")
        df_low = gr.DataFrame()
        gr.Markdown("### Medium")
        df_medium = gr.DataFrame()
        gr.Markdown("### High")
        df_high = gr.DataFrame()

    # Output order matches the 5-tuples yielded by run_quality_check.
    gr_check_btn.click(
        run_quality_check,
        inputs=[dataset_name, text_column, batch_size, num_examples],
        outputs=[progress_bar, plot, df_low, df_medium, df_high],
    )

demo.launch()