import asyncio
import logging
import os
import random
from datetime import timedelta
from statistics import mean
from typing import Annotated, Any, Iterable, Iterator, Union
from urllib.parse import urlencode

import fasttext
from cashews import cache
from dotenv import load_dotenv
from fastapi import FastAPI, Path, Query
from httpx import AsyncClient, Client, Timeout
from huggingface_hub import hf_hub_download
from iso639 import Lang
from starlette.responses import RedirectResponse
from toolz import concat, groupby, valmap
| |
|
# Configure the cashews backend used by the @cache decorators in this module
# ("mem://" -- presumably the in-process memory store; see cashews docs).
cache.setup("mem://")

logger = logging.getLogger(__name__)
app = FastAPI()
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
# Fail fast with a clear message. A bare `assert` is stripped under `python -O`,
# which would let the app start and authenticate every request as "Bearer None".
if not HF_TOKEN:
    raise RuntimeError(
        "HF_TOKEN environment variable must be set (e.g. in a .env file)"
    )
# NOTE(review): huggingface_hub reads HF_HUB_ENABLE_HF_TRANSFER when it is
# imported, and it is imported at the top of this module, so this assignment
# likely has no effect on downloads made by this process -- confirm, and set it
# before the import if hf_transfer downloads are actually wanted.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
| |
|
# fastText labels look like "__label__<lang>"; this is the length of that
# prefix, which is sliced off predicted labels.
FASTTEXT_PREFIX_LENGTH = len("__label__")

# Remote services / models.
BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
DEFAULT_FAST_TEXT_MODEL = "facebook/fasttext-language-identification"

# Every datasets-server request carries the Hub token.
headers = {"Authorization": f"Bearer {HF_TOKEN}"}

# 60 s default for every phase, except 120 s allowed for reading a response.
timeout = Timeout(60, read=120)

# One synchronous and one asynchronous client, sharing auth headers and timeouts.
client = Client(headers=headers, timeout=timeout)
async_client = AsyncClient(headers=headers, timeout=timeout)
| |
|
# Column names (matched case-insensitively by the prediction endpoint) that are
# likely to hold natural-language text worth running language identification on.
# NOTE(review): "sentence_1" sits next to "sentence2"; if "sentence1" was meant
# (to pair with "sentence2"), it should be added -- confirm.
TARGET_COLUMN_NAMES = {
    "text",
    "input",
    "tokens",
    "prompt",
    "instruction",
    "sentence_1",
    "question",
    "sentence2",
    "answer",
    "sentence",
    "response",
    "context",
    "query",
    "chosen",
    "rejected",
}
| |
|
| |
|
def datasets_server_valid_rows(hub_id: str):
    """Return True if the datasets server can serve rows for `hub_id`.

    Queries the ``/is-valid`` endpoint and treats the dataset as usable when
    either its ``viewer`` or its ``preview`` flag is truthy. This is a
    best-effort check: any failure (network error, non-JSON body, ...) is
    logged and reported as False.

    Note: this uses the synchronous ``client`` and blocks the calling thread
    until the request completes or times out.
    """
    try:
        # `params=` lets httpx encode hub_id, so characters such as "&" or "?"
        # in the (user-supplied) id cannot inject extra query parameters.
        resp = client.get(
            f"{BASE_DATASETS_SERVER_URL}/is-valid", params={"dataset": hub_id}
        )
        data = resp.json()
        return bool(data.get("viewer") or data.get("preview"))
    except Exception as e:
        logger.error(f"Failed to get is-valid for {hub_id}: {e}")
        return False
| |
|
| |
|
async def get_first_config_and_split_name(hub_id: str):
    """Return ``(config, split)`` for the first split listed for `hub_id`.

    Queries the datasets-server ``/splits`` endpoint. On any failure (network
    error, error payload without ``"splits"``, empty split list, ...) the error
    is logged and ``(None, None)`` is returned.
    """
    try:
        # Same base URL as the other helpers; `params=` encodes hub_id safely.
        resp = await async_client.get(
            f"{BASE_DATASETS_SERVER_URL}/splits", params={"dataset": hub_id}
        )
        first = resp.json()["splits"][0]
        return first["config"], first["split"]
    except Exception as e:
        logger.error(f"Failed to get splits for {hub_id}: {e}")
        return (None, None)
| |
|
| |
|
async def get_dataset_info(hub_id: str, config: str | None = None):
    """Fetch the datasets-server ``/info`` payload for one config of `hub_id`.

    When `config` is None, the dataset's first config (as reported by
    ``/splits``) is used.

    Returns the decoded JSON body, or None when no config can be determined
    or the request fails (non-2xx status, network error, invalid JSON). Like
    the other datasets-server helpers, failures are logged rather than raised.
    """
    if config is None:
        config, _ = await get_first_config_and_split_name(hub_id)
        if config is None:
            return None
    try:
        resp = await async_client.get(
            f"{BASE_DATASETS_SERVER_URL}/info",
            params={"dataset": hub_id, "config": config},
        )
        resp.raise_for_status()
        return resp.json()
    except Exception as e:
        logger.error(f"Failed to get info for {hub_id} (config {config}): {e}")
        return None
| |
|
| |
|
@cache(ttl=timedelta(minutes=5))
async def fetch_rows(url: str) -> list[dict]:
    """GET a datasets-server ``/rows`` URL and return its ``"rows"`` list.

    Results are cached for 5 minutes, keyed on `url`. A non-200 response is
    logged and yields an empty list. A payload without ``"rows"`` also yields
    an empty list, so callers always get a list.
    """
    response = await async_client.get(url)
    if response.status_code != 200:
        logger.error(f"Failed to fetch data: {response.status_code} ({url})")
        return []
    return response.json().get("rows") or []
| |
|
| |
|
| | |
async def get_random_rows(
    hub_id: str,
    total_length: int,
    number_of_rows: int,
    max_request_calls: int,
    config="default",
    split="train",
):
    """Sample roughly `number_of_rows` rows from one split via the datasets server.

    Makes up to `max_request_calls` requests. Each request fetches a
    contiguous window of at most 100 rows starting at a random offset.
    Sampling stops once at least `number_of_rows` rows have been collected.
    Windows are drawn independently, so they may overlap and rows can repeat.

    Returns the ``"row"`` payload of each fetched entry. Returns an empty list
    when `total_length`, `number_of_rows` or `max_request_calls` is not
    positive.
    """
    if total_length <= 0 or number_of_rows <= 0 or max_request_calls <= 0:
        return []
    rows_per_call = min(
        number_of_rows // max_request_calls,
        total_length // max_request_calls,
        100,
    )
    # For small splits/samples the floor divisions above reach 0, which would
    # divide by zero below; always fetch at least one row per call. This also
    # keeps rows_per_call <= total_length, so randint's range stays valid.
    rows_per_call = max(rows_per_call, 1)

    rows = []
    for _ in range(min(max_request_calls, number_of_rows // rows_per_call)):
        offset = random.randint(0, total_length - rows_per_call)
        query = urlencode(
            {
                "dataset": hub_id,
                "config": config,
                "split": split,
                "offset": offset,
                "length": rows_per_call,
            }
        )
        url = f"{BASE_DATASETS_SERVER_URL}/rows?{query}"
        logger.info(f"Fetching {url}")
        batch_rows = await fetch_rows(url)
        # Guard against a None result (a payload without "rows").
        rows.extend(batch_rows or [])
        if len(rows) >= number_of_rows:
            break
    return [row.get("row") for row in rows]
| |
|
| |
|
def load_model(repo_id: str) -> fasttext.FastText._FastText:
    """Download ``model.bin`` from the Hub repo `repo_id` and load it with fastText.

    Also makes sure a local ``code/models`` directory exists (relative to the
    working directory).

    NOTE(review): that directory is not passed to hf_hub_download, so the
    model presumably lands in the default Hub cache rather than there --
    confirm whether `cache_dir`/`local_dir` was intended or the mkdir can go.
    """
    # os.makedirs avoids a local pathlib import that would shadow FastAPI's Path.
    os.makedirs("code/models", exist_ok=True)
    downloaded = hf_hub_download(repo_id, "model.bin")
    return fasttext.load_model(downloaded)
| |
|
| |
|
# Load the default language-identification model once, at import time; the
# prediction helpers use this module-level instance.
model = load_model(repo_id=DEFAULT_FAST_TEXT_MODEL)
| |
|
| |
|
def yield_clean_rows(
    rows: Iterable[Union[str, list[str]]], min_length: int = 3
) -> Iterator[str]:
    """Yield single-line text snippets suitable for fastText prediction.

    - A string row is split on "\\n" into one snippet per line.
    - A list row (e.g. tokens) is joined with spaces into one snippet, with
      any embedded newlines replaced by spaces. Lists containing non-string
      items are skipped.
    - Other values are ignored.

    Snippets whose stripped length is below `min_length` (including blank or
    whitespace-only lines) are dropped. No yielded snippet contains "\\n",
    because fastText's predict() rejects multi-line input.
    """
    for row in rows:
        if isinstance(row, str):
            candidates = row.split("\n")
        elif isinstance(row, list):
            try:
                joined = " ".join(row)
            except TypeError:
                # Non-string items (e.g. dicts) cannot be joined; skip the row.
                continue
            candidates = [joined.replace("\n", " ")]
        else:
            continue
        for line in candidates:
            if len(line.strip()) >= min_length:
                yield line
| |
|
| |
|
def model_predict(inputs: str, k=1) -> list[dict[str, float]]:
    """Return the top-`k` fastText language predictions for one piece of text.

    Each prediction is ``{"label": <label without the fastText prefix>,
    "score": <probability>}``, in the order fastText returns them.
    """
    # fastText's predict() raises ValueError on text containing "\n" (it
    # processes one line at a time), so flatten any newlines first.
    labels, scores = model.predict(inputs.replace("\n", " "), k=k)
    return [
        {"label": label[FASTTEXT_PREFIX_LENGTH:], "score": prob}
        for label, prob in zip(labels, scores)
    ]
| |
|
| |
|
def get_label(x):
    """Return the ``"label"`` of prediction dict `x`, or None if it has none."""
    label = x.get("label")
    return label
| |
|
| |
|
def get_mean_score(preds):
    """Return the arithmetic mean of the ``"score"`` values in `preds`.

    Raises statistics.StatisticsError when `preds` is empty.
    """
    return mean(map(lambda pred: pred.get("score"), preds))
| |
|
| |
|
def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2):
    """Return the set of keys whose count is at least `threshold_percent` of the total.

    `threshold_percent` is a fraction (0.2 means 20%). The comparison is
    inclusive. An empty `counts_dict` yields an empty set.
    """
    cutoff = threshold_percent * sum(counts_dict.values())
    kept = set()
    for key, count in counts_dict.items():
        if count >= cutoff:
            kept.add(key)
    return kept
| |
|
| |
|
def try_parse_language(lang: str) -> str | None:
    """Map a fastText label such as ``"eng_Latn"`` to an ISO 639-1 code.

    Only the part before the first "_" is looked up with iso639's ``Lang``.
    Returns None (after logging) when the lookup fails. Also returns None when
    the language has no ISO 639-1 code: ``pt1`` is then empty, and returning
    "" would make such languages look like a valid code.
    """
    try:
        code = lang.split("_")[0]
        pt1 = Lang(code).pt1
    except Exception as e:
        logger.error(f"Failed to parse language {lang}: {e}")
        return None
    return pt1 or None
| |
|
| |
|
def predict_rows(
    rows, target_column, language_threshold_percent=0.2, return_raw_predictions=False
):
    """Summarise fastText language predictions for one column of sampled rows.

    Each row's `target_column` value is cleaned into single lines by
    ``yield_clean_rows``, and every line gets one prediction from
    ``model_predict``. Raw labels whose share of all predictions is below
    `language_threshold_percent` are dropped.

    Returns a dict with:
      - ``raw_model_prediction_summary``: raw fastText label -> mean score.
      - ``language_prediction_summary``: parsed language code (None when a
        label cannot be parsed) -> mean score over all kept predictions that
        map to that code.
      - ``hub_id`` / ``config``: placeholder strings for the caller to fill in.
      - ``raw_predictions``: every prediction; only present when
        `return_raw_predictions` is true.
    """
    # Rows that are not dicts, or that lack the column, contribute nothing.
    texts = (row.get(target_column) for row in rows if isinstance(row, dict))
    texts = (text for text in texts if text is not None)
    lines = list(yield_clean_rows(texts))
    predictions = list(concat(model_predict(line) for line in lines))

    predictions_by_lang = groupby(get_label, predictions)
    language_counts = valmap(len, predictions_by_lang)
    keys_to_keep = filter_by_frequency(
        language_counts, threshold_percent=language_threshold_percent
    )
    filtered_dict = {k: v for k, v in predictions_by_lang.items() if k in keys_to_keep}
    raw_model_prediction_summary = dict(valmap(get_mean_score, filtered_dict))

    # Several raw labels can map to the same code: try_parse_language keeps only
    # the text before the first "_", and every unparseable label becomes None.
    # Pool their predictions instead of letting the last label overwrite the rest.
    pooled: dict = {}
    for raw_label, preds in filtered_dict.items():
        pooled.setdefault(try_parse_language(raw_label), []).extend(preds)
    parsed_langs = dict(valmap(get_mean_score, pooled))

    default_data = {
        "language_prediction_summary": parsed_langs,
        "raw_model_prediction_summary": raw_model_prediction_summary,
        "hub_id": "hub_id",
        "config": "config",
    }
    if return_raw_predictions:
        default_data["raw_predictions"] = predictions
    return default_data
| |
|
| |
|
@app.get("/", include_in_schema=False)
def root():
    """Redirect the bare root URL to the interactive API docs."""
    docs_url = "/docs"
    return RedirectResponse(docs_url)
| |
|
| |
|
def _find_target_column(column_names: list[str]) -> str | None:
    """Return the first name in `column_names` that is a known text column.

    Matching against TARGET_COLUMN_NAMES is case-insensitive. Candidates are
    scanned in the dataset's own column order. TARGET_COLUMN_NAMES is a set,
    and the iteration order of a set of strings can differ between interpreter
    runs (hash randomisation), so scanning the set would make the chosen
    column vary from process to process.
    """
    targets = {name.lower() for name in TARGET_COLUMN_NAMES}
    return next((name for name in column_names if name.lower() in targets), None)


@app.get("/predict_dataset_language/{hub_id:path}")
@cache(ttl=timedelta(minutes=10))
async def predict_language(
    hub_id: Annotated[str, Path(title="The hub id of the dataset to predict")],
    config: str | None = None,
    split: str | None = None,
    max_request_calls: Annotated[
        int, Query(title="Max number of requests to datasets server", gt=0, le=50)
    ] = 10,
    number_of_rows: Annotated[
        int, Query(title="Number of rows to sample from the split", gt=0)
    ] = 1000,
    language_threshold_percent: float = 0.2,
) -> dict[Any, Any] | None:
    """Predict the language(s) of a Hub dataset from a random sample of its rows.

    Steps:
      1. Check that the datasets server can serve the dataset.
      2. Resolve the config and split, defaulting to the first ones available.
      3. Pick a text column listed in TARGET_COLUMN_NAMES.
      4. Sample rows and summarise the fastText predictions.

    Returns the summary from ``predict_rows`` with ``hub_id``, ``config`` and
    ``split`` filled in, or None (after logging why) when any step fails.
    Results are cached for 10 minutes.
    """
    # datasets_server_valid_rows uses the synchronous httpx client. Running it
    # in a worker thread keeps it from blocking the event loop, and every other
    # in-flight request, for up to the client timeout.
    is_valid = await asyncio.to_thread(datasets_server_valid_rows, hub_id)
    if not is_valid:
        logger.error(f"Dataset {hub_id} is not accessible via the datasets server.")
        return None

    if not config:
        first_config, first_split = await get_first_config_and_split_name(hub_id)
        if first_config is None:
            logger.error(f"Could not retrieve configuration for dataset {hub_id}")
            return None
        config = first_config
        if not split:
            split = first_split
    # If only `split` is missing, it is chosen below from the requested config's
    # own splits. The first split of the *first* config need not exist in `config`.

    info = await get_dataset_info(hub_id, config)
    if info is None:
        logger.error(f"Dataset {hub_id} is not accessible via the datasets server.")
        return None

    dataset_info = info.get("dataset_info")
    if not dataset_info:
        logger.error(f"No dataset_info available for {hub_id}")
        return None

    splits_info = dataset_info.get("splits") or {}
    if not split:
        split = next(iter(splits_info), None)
        if split is None:
            logger.error(f"Could not retrieve split for dataset {hub_id}")
            return None
    total_rows_for_split = (splits_info.get(split) or {}).get("num_examples")
    if not total_rows_for_split:
        # Unknown split, or no row count: sampling would crash or fetch nothing.
        logger.error(
            f"Split {split} of dataset {hub_id} (config {config}) is unknown or has no rows"
        )
        return None

    column_names = list(dataset_info.get("features") or {})
    logger.info(f"Column names: {column_names}")
    target_column = _find_target_column(column_names)
    if target_column is None:
        logger.error(
            f"Dataset {hub_id} {column_names} does not contain any of the target columns {TARGET_COLUMN_NAMES}"
        )
        return None
    logger.info(f"Using column {target_column} for language detection")

    random_rows = await get_random_rows(
        hub_id,
        total_rows_for_split,
        number_of_rows,
        max_request_calls,
        config,
        split,
    )

    logger.info(f"Predicting language for {len(random_rows)} rows")
    predictions = predict_rows(
        random_rows,
        target_column,
        language_threshold_percent=language_threshold_percent,
    )
    predictions["hub_id"] = hub_id
    predictions["config"] = config
    predictions["split"] = split
    return predictions