""" Асинхронный сбор данных с ru.wikipedia.org для fine-tuning. Ускоряет исходный [collect_data.py](collect_data.py) за счет конкурентной загрузки картинок, но оставляет API-запросы к Wikipedia достаточно бережными. Установка: pip install aiohttp tqdm Примеры: python collect_data_async.py python collect_data_async.py --max-total 10000 --max-depth 2 --resume """ from __future__ import annotations import argparse import asyncio import hashlib import json from pathlib import Path from typing import Any, AsyncIterator, TextIO from urllib.parse import unquote import aiohttp from tqdm import tqdm API_URL = "https://ru.wikipedia.org/w/api.php" HEADERS = { # Укажи свои контакты при желании; для Wikimedia лучше честный bot UA, а не браузерный. "User-Agent": "ML2HomeworkCollector/1.0 (educational project; contact: local-run)", "Accept-Encoding": "gzip, deflate", } REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=30) CATEGORIES = [ "Категория:Кошки (род)", "Категория:Породы собак", "Категория:Совообразные", "Категория:Попугаеобразные", "Категория:Ястребообразные", "Категория:Бабочки", "Категория:Жуки", "Категория:Пресноводные рыбы", "Категория:Акулы", "Категория:Съедобные грибы", "Категория:Ядовитые грибы", "Категория:Фрукты", "Категория:Овощи", "Категория:Ягоды", "Категория:Орехи", "Категория:Хвойные", "Категория:Цветковые растения", "Категория:Кактусовые", "Категория:Легковые автомобили", "Категория:Мотоциклы", "Категория:Вертолёты", "Категория:Самолёты", "Категория:Танки", "Категория:Струнные музыкальные инструменты", "Категория:Духовые музыкальные инструменты", "Категория:Ударные музыкальные инструменты", "Категория:Супы", "Категория:Салаты", "Категория:Пирожные", "Категория:Мосты России", "Категория:Мосты Европы", "Категория:Замки Европы", "Категория:Храмы России", "Категория:Небоскрёбы", "Категория:Маяки", "Категория:Вулканы", "Категория:Озёра России", "Категория:Водопады", "Категория:Холодное оружие", "Категория:Огнестрельное оружие", "Категория:Минералы", "Категория:Драгоценные камни", "Категория:Монеты", "Категория:Флаги государств", ] SKIP_IMAGE_EXTENSIONS = {".svg", ".gif", ".ogg", ".ogv", ".webm", ".pdf", ".djvu"} DATA_DIR = Path("data") IMAGES_DIR = DATA_DIR / "images" METADATA_FILE = DATA_DIR / "metadata.jsonl" CHECKPOINT_FILE = DATA_DIR / "checkpoint.json" class AsyncCollector: def __init__(self, max_total: int, max_depth: int, resume: bool): self.max_total = max_total self.max_depth = max_depth self.resume = resume self.collected: set[str] = set() self.session: aiohttp.ClientSession | None = None self.meta_f: TextIO | None = None self.pbar: tqdm | None = None # API лучше не долбить параллельно; ускорение в основном будет на картинках. self.api_sem = asyncio.Semaphore(1) self.img_sem = asyncio.Semaphore(8) async def init(self) -> None: IMAGES_DIR.mkdir(parents=True, exist_ok=True) connector = aiohttp.TCPConnector(limit=16) self.session = aiohttp.ClientSession( headers=HEADERS, connector=connector, timeout=REQUEST_TIMEOUT, ) if self.resume and CHECKPOINT_FILE.exists(): with open(CHECKPOINT_FILE, encoding="utf-8") as f: self.collected = set(json.load(f).get("collected_titles", [])) mode = "a" if self.resume and METADATA_FILE.exists() else "w" self.meta_f = open(METADATA_FILE, mode, encoding="utf-8") self.pbar = tqdm(total=self.max_total, initial=len(self.collected), desc="Collecting") async def close(self) -> None: if self.session is not None: await self.session.close() if self.meta_f is not None: self.meta_f.close() if self.pbar is not None: self.pbar.close() def save_checkpoint(self) -> None: with open(CHECKPOINT_FILE, "w", encoding="utf-8") as f: json.dump({"collected_titles": list(self.collected)}, f, ensure_ascii=False) async def api_query(self, **params: Any) -> dict[str, Any]: if self.session is None: raise RuntimeError("Session is not initialized") normalized_params: dict[str, str | int | float] = { "format": "json", "action": "query", } for key, value in params.items(): if isinstance(value, bool): normalized_params[key] = "1" if value else "0" elif isinstance(value, (str, int, float)): normalized_params[key] = value else: normalized_params[key] = str(value) async with self.api_sem: await asyncio.sleep(0.05) for attempt in range(4): try: async with self.session.get(API_URL, params=normalized_params) as resp: if resp.status in (403, 429): wait = int(resp.headers.get("Retry-After", 5 * (attempt + 1))) tqdm.write(f"API limited ({resp.status}), sleeping {wait}s") await asyncio.sleep(wait) continue resp.raise_for_status() return await resp.json() except Exception as e: if attempt == 3: tqdm.write(f"API Error: {e}") return {} await asyncio.sleep(1.5 * (attempt + 1)) return {} async def download_image(self, url: str, save_path: Path) -> bool: if self.session is None: raise RuntimeError("Session is not initialized") if save_path.exists(): return True async with self.img_sem: for attempt in range(3): try: async with self.session.get(url) as resp: if resp.status in (403, 429): wait = int(resp.headers.get("Retry-After", 3 * (attempt + 1))) await asyncio.sleep(wait) continue resp.raise_for_status() content = await resp.read() with open(save_path, "wb") as f: f.write(content) return True except Exception: if attempt == 2: return False await asyncio.sleep(1 + attempt) return False async def iter_category_pages(self, category: str, max_per_category: int) -> AsyncIterator[str]: visited_cats: set[str] = set() count = 0 async def _crawl(cat: str, depth: int) -> AsyncIterator[str]: nonlocal count if depth > self.max_depth or cat in visited_cats or count >= max_per_category: return visited_cats.add(cat) cmcontinue: str | None = None subcats: list[str] = [] while count < max_per_category: params: dict[str, Any] = { "list": "categorymembers", "cmtitle": cat, "cmlimit": 50, "cmtype": "page|subcat", } if cmcontinue: params["cmcontinue"] = cmcontinue data = await self.api_query(**params) members = data.get("query", {}).get("categorymembers", []) if not members and "error" in data: return for member in members: if count >= max_per_category: return if member.get("ns") == 0: title = member.get("title") if isinstance(title, str): count += 1 yield title elif member.get("ns") == 14: title = member.get("title") if isinstance(title, str): subcats.append(title) cmcontinue = data.get("continue", {}).get("cmcontinue") if not cmcontinue: break for subcat in subcats: if count >= max_per_category: return async for title in _crawl(subcat, depth + 1): yield title async for title in _crawl(category, 0): yield title async def process_batch(self, batch: list[str], category: str) -> int: data = await self.api_query( titles="|".join(batch), prop="extracts|pageimages", exintro=1, explaintext=1, exsectionformat="plain", piprop="thumbnail", pithumbsize=512, pilimit="max", ) pages = data.get("query", {}).get("pages", {}) tasks: list[asyncio.Task[bool]] = [] records: list[dict[str, str]] = [] for page_id, page in pages.items(): try: if int(page_id) < 0: continue except Exception: continue title = page.get("title", "") if not isinstance(title, str) or title in self.collected: continue extract = page.get("extract", "") thumb = page.get("thumbnail", {}) image_url = thumb.get("source", "") if isinstance(thumb, dict) else "" if not isinstance(extract, str) or len(extract.strip()) < 50: continue if not isinstance(image_url, str) or not image_url: continue ext = Path(unquote(image_url)).suffix.lower().split("?")[0] if ext in SKIP_IMAGE_EXTENSIONS: continue safe_name = hashlib.md5(title.encode("utf-8")).hexdigest()[:12] final_ext = ext if ext and len(ext) <= 5 else ".jpg" img_path = IMAGES_DIR / f"{safe_name}{final_ext}" records.append( { "title": title, "text": extract.strip(), "image_path": str(img_path), "image_url": image_url, "category": category, } ) tasks.append(asyncio.create_task(self.download_image(image_url, img_path))) if not tasks: return 0 results = await asyncio.gather(*tasks) if self.meta_f is None or self.pbar is None: raise RuntimeError("Output files are not initialized") added = 0 for record, success in zip(records, results): if not success: continue self.meta_f.write(json.dumps(record, ensure_ascii=False) + "\n") self.collected.add(record["title"]) added += 1 self.pbar.update(1) self.meta_f.flush() return added async def collect_from_category(self, category: str, limit: int) -> int: cat_count = 0 batch: list[str] = [] async for title in self.iter_category_pages(category, limit * 3): if cat_count >= limit or len(self.collected) >= self.max_total: break if title in self.collected: continue batch.append(title) if len(batch) >= 50: cat_count += await self.process_batch(batch, category) batch = [] if batch and cat_count < limit and len(self.collected) < self.max_total: cat_count += await self.process_batch(batch, category) return cat_count async def run(self) -> None: await self.init() try: base_per_cat = self.max_total // len(CATEGORIES) tqdm.write(f"Pass 1: up to {base_per_cat} per category ({len(CATEGORIES)} categories)") cat_stats: dict[str, int] = {} for category in CATEGORIES: if len(self.collected) >= self.max_total: break tqdm.write(f"\n📂 {category}") n = await self.collect_from_category(category, base_per_cat) cat_stats[category] = n tqdm.write(f" ✓ {n} pairs") self.save_checkpoint() remaining = self.max_total - len(self.collected) if remaining > 0: big_cats = sorted(cat_stats, key=lambda c: cat_stats[c], reverse=True) extra_per_cat = remaining // min(len(big_cats), 10) + 50 tqdm.write(f"\nPass 2: collecting {remaining} more from largest categories") for category in big_cats: if len(self.collected) >= self.max_total: break tqdm.write(f"\n📂 {category} (extra)") n = await self.collect_from_category(category, extra_per_cat) tqdm.write(f" ✓ {n} extra pairs") self.save_checkpoint() finally: await self.close() print(f"\nDone! Collected {len(self.collected)} pairs.") print(f"Images: {IMAGES_DIR}") print(f"Metadata: {METADATA_FILE}") def main() -> None: parser = argparse.ArgumentParser(description="Collect Wikipedia image-text pairs (async)") parser.add_argument("--max-total", type=int, default=10000, help="Total pairs to collect") parser.add_argument("--max-depth", type=int, default=2, help="Max category recursion depth") parser.add_argument("--resume", action="store_true", help="Resume from checkpoint") args = parser.parse_args() collector = AsyncCollector(args.max_total, args.max_depth, args.resume) asyncio.run(collector.run()) if __name__ == "__main__": main()