"""Guess the language(s) of a Hugging Face dataset.

Samples random rows via the datasets-server API and runs a fastText
language-identification model over a likely text column.
"""
import os
import random
from statistics import mean
from typing import Iterator, Union

import fasttext
import gradio as gr
from dotenv import load_dotenv
from httpx import Client, Timeout
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import logging
from toolz import concat, groupby, valmap
|
logger = logging.get_logger(__name__)

load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
|
BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
DEFAULT_FAST_TEXT_MODEL = "laurievb/OpenLID"

headers = {
    "authorization": f"Bearer {HF_TOKEN}",
}
timeout = Timeout(60, read=120)
client = Client(headers=headers, timeout=timeout)

# Column names that commonly hold free text suitable for language detection.
TARGET_COLUMN_NAMES = {
    "text",
    "input",
    "tokens",
    "prompt",
    "instruction",
    "sentence_1",
    "question",
    "sentence2",
    "answer",
    "sentence",
    "response",
    "context",
    "query",
}

|
def datasets_server_valid_rows(hub_id: str) -> bool:
    """Check whether the dataset viewer works for `hub_id`."""
    resp = client.get(f"{BASE_DATASETS_SERVER_URL}/is-valid?dataset={hub_id}")
    resp.raise_for_status()
    return resp.json()["viewer"]
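
# The /is-valid response looks roughly like {"viewer": true, "preview": true, ...};
# only the "viewer" flag is used here (the exact set of other flags is an assumption).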
|
def get_first_config_and_split_name(hub_id: str) -> tuple[str, str]:
    resp = client.get(f"{BASE_DATASETS_SERVER_URL}/splits?dataset={hub_id}")
    resp.raise_for_status()
    data = resp.json()
    return data["splits"][0]["config"], data["splits"][0]["split"]
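
# /splits returns roughly {"splits": [{"dataset": ..., "config": ..., "split": ...}, ...]};
# the first entry is used as the default config/split.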
|
def get_dataset_info(hub_id: str, config: str | None = None):
    if config is None:
        # Fall back to the first config reported by the datasets server.
        config, _ = get_first_config_and_split_name(hub_id)
    resp = client.get(
        f"{BASE_DATASETS_SERVER_URL}/info?dataset={hub_id}&config={config}"
    )
    resp.raise_for_status()
    return resp.json()
|
def get_random_rows(
    hub_id: str,
    total_length: int,
    number_of_rows: int,
    max_request_calls: int,
    config: str = "default",
    split: str = "train",
):
    rows = []
    # The /rows endpoint caps each request at 100 rows; also guard against a
    # zero rows_per_call when number_of_rows < max_request_calls.
    rows_per_call = min(
        number_of_rows // max_request_calls, total_length // max_request_calls
    )
    rows_per_call = min(max(rows_per_call, 1), 100)
    for _ in range(min(max_request_calls, number_of_rows // rows_per_call)):
        offset = random.randint(0, max(total_length - rows_per_call, 0))
        url = (
            f"{BASE_DATASETS_SERVER_URL}/rows?dataset={hub_id}&config={config}"
            f"&split={split}&offset={offset}&length={rows_per_call}"
        )
        response = client.get(url)
        if response.status_code == 200:
            data = response.json()
            rows.extend(data.get("rows", []))
        else:
            logger.warning(f"Failed to fetch rows: {response.status_code} ({url})")
        if len(rows) >= number_of_rows:
            break
    return [row.get("row") for row in rows]
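
# Each item in a /rows response is shaped roughly like
# {"row_idx": 0, "row": {"text": "..."}, "truncated_cells": []},
# hence the final unwrapping of the "row" field above.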
|
def load_model(repo_id: str) -> fasttext.FastText._FastText:
    model_path = hf_hub_download(repo_id, filename="model.bin")
    return fasttext.load_model(model_path)
|
def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
    for row in rows:
        if isinstance(row, str):
            # Split multi-line strings and drop empty lines.
            for line in row.split("\n"):
                if line:
                    yield line
        elif isinstance(row, list):
            # Token lists are joined back into a single string.
            try:
                line = " ".join(row)
            except TypeError:
                continue
            if len(line) >= min_length:
                yield line
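
# e.g. list(yield_clean_rows(["a\nb", ["two", "tokens"]])) == ["a", "b", "two tokens"]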
|
# fastText labels carry a "__label__" prefix (9 characters), e.g. "__label__fra_Latn".
FASTTEXT_PREFIX_LENGTH = 9

|
model = fasttext.load_model(
    hf_hub_download("facebook/fasttext-language-identification", "model.bin")
)
|
def model_predict(inputs: str, k=1) -> list[dict[str, float]]:
    # fastText returns a (labels, probabilities) pair of parallel arrays.
    predictions = model.predict(inputs, k=k)
    return [
        {"label": label[FASTTEXT_PREFIX_LENGTH:], "score": prob}
        for label, prob in zip(predictions[0], predictions[1])
    ]
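
# Illustrative output (the score is made up):
# model_predict("Bonjour tout le monde") -> [{"label": "fra_Latn", "score": 0.99}]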
|
def get_label(x):
    return x.get("label")


def get_mean_score(preds):
    return mean([pred.get("score") for pred in preds])
|
def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2):
    """Return the keys whose counts are at least `threshold_percent` of the total."""
    total = sum(counts_dict.values())
    threshold = total * threshold_percent
    return {k for k, v in counts_dict.items() if v >= threshold}
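
# e.g. filter_by_frequency({"eng_Latn": 80, "fra_Latn": 15, "deu_Latn": 5})
# -> {"eng_Latn"}: with the default threshold only 80 >= 100 * 0.2.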
|
def predict_rows(rows, target_column, language_threshold_percent=0.2):
    rows = (row.get(target_column) for row in rows)
    rows = (row for row in rows if row is not None)
    rows = list(yield_clean_rows(rows))
    predictions = [model_predict(row) for row in rows]
    predictions = [pred for pred in predictions if pred is not None]
    predictions = list(concat(predictions))
    predictions_by_lang = groupby(get_label, predictions)
    language_counts = valmap(len, predictions_by_lang)
    keys_to_keep = filter_by_frequency(
        language_counts, threshold_percent=language_threshold_percent
    )
    filtered_dict = {k: v for k, v in predictions_by_lang.items() if k in keys_to_keep}
    return {
        "predictions": dict(valmap(get_mean_score, filtered_dict)),
        "pred": predictions,
    }
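
# The result maps each language that clears the frequency threshold to its
# mean confidence, e.g. {"predictions": {"eng_Latn": 0.97}, "pred": [...]}.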
|
def predict_language(
    hub_id: str,
    config: str | None = None,
    split: str | None = None,
    max_request_calls: int = 10,
) -> dict:
    is_valid = datasets_server_valid_rows(hub_id)
    if not is_valid:
        raise gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
    if not config:
        config, split = get_first_config_and_split_name(hub_id)
    info = get_dataset_info(hub_id, config)
    if info is None or "dataset_info" not in info:
        raise gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
    dataset_info = info["dataset_info"]
    total_rows_for_split = dataset_info.get("splits").get(split).get("num_examples")
    features = dataset_info.get("features")
    column_names = set(features.keys())
    logger.info(f"Column names: {column_names}")
    if not column_names.intersection(TARGET_COLUMN_NAMES):
        raise gr.Error(
            f"Dataset {hub_id} does not contain any of the target columns"
            f" {TARGET_COLUMN_NAMES}"
        )
    # NB: TARGET_COLUMN_NAMES is a set, so which matching column wins is arbitrary.
    for column in TARGET_COLUMN_NAMES:
        if column in column_names:
            target_column = column
            logger.info(f"Using column {target_column} for language detection")
            break
    random_rows = get_random_rows(
        hub_id, total_rows_for_split, 1000, max_request_calls, config, split
    )
    logger.info(f"Predicting language for {len(random_rows)} rows")
    return predict_rows(random_rows, target_column)
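
# Example call (dataset id purely illustrative):
# predict_language("stanfordnlp/imdb")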
|
interface = gr.Interface(predict_language, inputs="text", outputs="json")
interface.queue()
interface.launch()