"""Download and prepare the WIT (Wikipedia-based Image Text) dataset.

Filters the WIT TSV to a single language, downloads the referenced images and
writes JSON-lines train/valid/test splits containing image paths and captions.
"""
import argparse
import json
import logging
import os
import time
import urllib.error
import urllib.request
from typing import List

import pandas as pd
from tqdm import tqdm

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


def split_and_save_datasets(
    lines: List[str], output_dir: str, train_proportion: float, valid_proportion: float
):
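    """Split the JSON lines into train/valid/test files under ``output_dir``.

    The test split receives whatever proportion remains after
    ``train_proportion`` and ``valid_proportion``.
    """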
    total_lines = len(lines)
    train_lines = lines[: int(total_lines * train_proportion)]
    valid_lines = lines[
        int(total_lines * train_proportion) : int(
            total_lines * (train_proportion + valid_proportion)
        )
    ]
    test_lines = lines[int(total_lines * (train_proportion + valid_proportion)) :]

    with open(f"{output_dir}/train_dataset.json", "w") as f:
        f.write("\n".join(train_lines))

    with open(f"{output_dir}/valid_dataset.json", "w") as f:
        f.write("\n".join(valid_lines))

    with open(f"{output_dir}/test_dataset.json", "w") as f:
        f.write("\n".join(test_lines))


def prepare_wit(
    tsv: str,
    language: str,
    output_dir: str,
    seed: int,
    train_proportion: float,
    valid_proportion: float,
    backup_period: int,
    language_col: str = "language",
    caption_col: str = "caption_reference_description",
    url_col: str = "image_url",
    pause: float = 0.875,
    retries: int = 10,
):
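    """Download the images referenced in the WIT ``tsv`` for one ``language``.

    Each successfully downloaded image is recorded as a JSON line with its local
    path and caption; train/valid/test splits are rewritten roughly every
    ``backup_period`` successful downloads and once more when the loop ends.
    """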
    os.makedirs(output_dir, exist_ok=True)
    logger.info("Loading dataset")
    df = pd.read_csv(tsv, sep="\t", engine="python")
    # Skip images that are already present in output_dir.
    existing_files = set(os.listdir(output_dir))
    not_exists_condition = ~(
        df[url_col].map(lambda x: x.split("/")[-1][-100:]).isin(existing_files)
    )
    # Keep only rows in the requested language that have a reference caption
    # and have not been downloaded yet.
    df = df[
        (df[language_col] == language)
        & (~df[caption_col].isnull())
        & not_exists_condition
    ]
    # Shuffle the rows with a fixed seed for reproducibility.
    df = df.sample(frac=1.0, random_state=seed)
    logger.info(f"Trying to download {df.shape[0]} files")
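    # Download each image, retrying on HTTP errors, and record one JSON line
    # per successful download.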
    lines = []
    count = 0
    try:
        with tqdm(total=len(df)) as pbar:
            for _, row in df.iterrows():
                url = row[url_col]
                caption = row[caption_col]
                # Keep only the last 100 characters of the file name to avoid
                # overly long paths.
                image_filename = url.split("/")[-1][-100:]
                image_path = f"{output_dir}/{image_filename}"
                for retry in range(retries):
                    try:
                        urllib.request.urlretrieve(url, image_path)
                        lines.append(
                            json.dumps(
                                {"image_path": image_path, "captions": [caption]},
                                ensure_ascii=False,
                            )
                        )
                        count += 1
                        break
                    except urllib.error.HTTPError:
                        # Back off before retrying the download.
                        time.sleep(pause * 10)
                else:
                    # Every retry raised an HTTPError: give up on this image.
                    logger.info(f"Skipping {image_filename}")
                # Periodically persist the splits so progress survives a crash.
                if count > 0 and count % backup_period == 0:
                    logger.info(f"Saving dataset backup: Number of lines {len(lines)}")
                    split_and_save_datasets(
                        lines, output_dir, train_proportion, valid_proportion
                    )
                pbar.update(1)
    finally:
        # Always write the final splits, even if the loop is interrupted.
        split_and_save_datasets(lines, output_dir, train_proportion, valid_proportion)


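# Example invocation (the script filename below is assumed):
#   python prepare_wit.py --tsv /path/to/wit_v1.train.all-1percent_sample.tsv \
#       --language es --output_dir /path/to/prepared_dataset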
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Download and prepare the WIT dataset")
    parser.add_argument(
        "--tsv",
        type=str,
        default=f"/home/{os.environ['USER']}/data/wit/wit_v1.train.all-1percent_sample.tsv",
    )
    parser.add_argument("--language", type=str, default="es")
    parser.add_argument(
        "--output_dir",
        type=str,
        default=f"/home/{os.environ['USER']}/data/wit/prepared_dataset",
    )
    parser.add_argument("--random_seed", type=int, default=0)
    parser.add_argument("--train_proportion", type=float, default=0.8)
    parser.add_argument("--valid_proportion", type=float, default=0.1)
    parser.add_argument("--backup_period", type=int, default=1000)

    args = parser.parse_args()
    assert (
        args.train_proportion + args.valid_proportion < 1.0
    ), "The sum of train_proportion and valid_proportion has to be < 1.0"
    prepare_wit(
        args.tsv,
        args.language,
        args.output_dir,
        args.random_seed,
        args.train_proportion,
        args.valid_proportion,
        args.backup_period,
    )