File size: 15,243 Bytes
ece157f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
"""
Асинхронный сбор данных с ru.wikipedia.org для fine-tuning.
Ускоряет исходный [collect_data.py](collect_data.py) за счет конкурентной загрузки картинок,
но оставляет API-запросы к Wikipedia достаточно бережными.

Установка:
    pip install aiohttp tqdm

Примеры:
    python collect_data_async.py
    python collect_data_async.py --max-total 10000 --max-depth 2 --resume
"""

from __future__ import annotations

import argparse
import asyncio
import hashlib
import json
from pathlib import Path
from typing import Any, AsyncIterator, TextIO
from urllib.parse import unquote

import aiohttp
from tqdm import tqdm

API_URL = "https://ru.wikipedia.org/w/api.php"
HEADERS = {
    # Укажи свои контакты при желании; для Wikimedia лучше честный bot UA, а не браузерный.
    "User-Agent": "ML2HomeworkCollector/1.0 (educational project; contact: local-run)",
    "Accept-Encoding": "gzip, deflate",
}
REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=30)

CATEGORIES = [
    "Категория:Кошки (род)",
    "Категория:Породы собак",
    "Категория:Совообразные",
    "Категория:Попугаеобразные",
    "Категория:Ястребообразные",
    "Категория:Бабочки",
    "Категория:Жуки",
    "Категория:Пресноводные рыбы",
    "Категория:Акулы",
    "Категория:Съедобные грибы",
    "Категория:Ядовитые грибы",
    "Категория:Фрукты",
    "Категория:Овощи",
    "Категория:Ягоды",
    "Категория:Орехи",
    "Категория:Хвойные",
    "Категория:Цветковые растения",
    "Категория:Кактусовые",
    "Категория:Легковые автомобили",
    "Категория:Мотоциклы",
    "Категория:Вертолёты",
    "Категория:Самолёты",
    "Категория:Танки",
    "Категория:Струнные музыкальные инструменты",
    "Категория:Духовые музыкальные инструменты",
    "Категория:Ударные музыкальные инструменты",
    "Категория:Супы",
    "Категория:Салаты",
    "Категория:Пирожные",
    "Категория:Мосты России",
    "Категория:Мосты Европы",
    "Категория:Замки Европы",
    "Категория:Храмы России",
    "Категория:Небоскрёбы",
    "Категория:Маяки",
    "Категория:Вулканы",
    "Категория:Озёра России",
    "Категория:Водопады",
    "Категория:Холодное оружие",
    "Категория:Огнестрельное оружие",
    "Категория:Минералы",
    "Категория:Драгоценные камни",
    "Категория:Монеты",
    "Категория:Флаги государств",
]

SKIP_IMAGE_EXTENSIONS = {".svg", ".gif", ".ogg", ".ogv", ".webm", ".pdf", ".djvu"}

DATA_DIR = Path("data")
IMAGES_DIR = DATA_DIR / "images"
METADATA_FILE = DATA_DIR / "metadata.jsonl"
CHECKPOINT_FILE = DATA_DIR / "checkpoint.json"


class AsyncCollector:
    def __init__(self, max_total: int, max_depth: int, resume: bool):
        self.max_total = max_total
        self.max_depth = max_depth
        self.resume = resume
        self.collected: set[str] = set()
        self.session: aiohttp.ClientSession | None = None
        self.meta_f: TextIO | None = None
        self.pbar: tqdm | None = None

        # API лучше не долбить параллельно; ускорение в основном будет на картинках.
        self.api_sem = asyncio.Semaphore(1)
        self.img_sem = asyncio.Semaphore(8)

    async def init(self) -> None:
        IMAGES_DIR.mkdir(parents=True, exist_ok=True)
        connector = aiohttp.TCPConnector(limit=16)
        self.session = aiohttp.ClientSession(
            headers=HEADERS,
            connector=connector,
            timeout=REQUEST_TIMEOUT,
        )

        if self.resume and CHECKPOINT_FILE.exists():
            with open(CHECKPOINT_FILE, encoding="utf-8") as f:
                self.collected = set(json.load(f).get("collected_titles", []))

        mode = "a" if self.resume and METADATA_FILE.exists() else "w"
        self.meta_f = open(METADATA_FILE, mode, encoding="utf-8")
        self.pbar = tqdm(total=self.max_total, initial=len(self.collected), desc="Collecting")

    async def close(self) -> None:
        if self.session is not None:
            await self.session.close()
        if self.meta_f is not None:
            self.meta_f.close()
        if self.pbar is not None:
            self.pbar.close()

    def save_checkpoint(self) -> None:
        with open(CHECKPOINT_FILE, "w", encoding="utf-8") as f:
            json.dump({"collected_titles": list(self.collected)}, f, ensure_ascii=False)

    async def api_query(self, **params: Any) -> dict[str, Any]:
        if self.session is None:
            raise RuntimeError("Session is not initialized")

        normalized_params: dict[str, str | int | float] = {
            "format": "json",
            "action": "query",
        }
        for key, value in params.items():
            if isinstance(value, bool):
                normalized_params[key] = "1" if value else "0"
            elif isinstance(value, (str, int, float)):
                normalized_params[key] = value
            else:
                normalized_params[key] = str(value)

        async with self.api_sem:
            await asyncio.sleep(0.05)
            for attempt in range(4):
                try:
                    async with self.session.get(API_URL, params=normalized_params) as resp:
                        if resp.status in (403, 429):
                            wait = int(resp.headers.get("Retry-After", 5 * (attempt + 1)))
                            tqdm.write(f"API limited ({resp.status}), sleeping {wait}s")
                            await asyncio.sleep(wait)
                            continue
                        resp.raise_for_status()
                        return await resp.json()
                except Exception as e:
                    if attempt == 3:
                        tqdm.write(f"API Error: {e}")
                        return {}
                    await asyncio.sleep(1.5 * (attempt + 1))
        return {}

    async def download_image(self, url: str, save_path: Path) -> bool:
        if self.session is None:
            raise RuntimeError("Session is not initialized")
        if save_path.exists():
            return True

        async with self.img_sem:
            for attempt in range(3):
                try:
                    async with self.session.get(url) as resp:
                        if resp.status in (403, 429):
                            wait = int(resp.headers.get("Retry-After", 3 * (attempt + 1)))
                            await asyncio.sleep(wait)
                            continue
                        resp.raise_for_status()
                        content = await resp.read()
                        with open(save_path, "wb") as f:
                            f.write(content)
                        return True
                except Exception:
                    if attempt == 2:
                        return False
                    await asyncio.sleep(1 + attempt)
        return False

    async def iter_category_pages(self, category: str, max_per_category: int) -> AsyncIterator[str]:
        visited_cats: set[str] = set()
        count = 0

        async def _crawl(cat: str, depth: int) -> AsyncIterator[str]:
            nonlocal count
            if depth > self.max_depth or cat in visited_cats or count >= max_per_category:
                return
            visited_cats.add(cat)

            cmcontinue: str | None = None
            subcats: list[str] = []

            while count < max_per_category:
                params: dict[str, Any] = {
                    "list": "categorymembers",
                    "cmtitle": cat,
                    "cmlimit": 50,
                    "cmtype": "page|subcat",
                }
                if cmcontinue:
                    params["cmcontinue"] = cmcontinue

                data = await self.api_query(**params)
                members = data.get("query", {}).get("categorymembers", [])
                if not members and "error" in data:
                    return

                for member in members:
                    if count >= max_per_category:
                        return
                    if member.get("ns") == 0:
                        title = member.get("title")
                        if isinstance(title, str):
                            count += 1
                            yield title
                    elif member.get("ns") == 14:
                        title = member.get("title")
                        if isinstance(title, str):
                            subcats.append(title)

                cmcontinue = data.get("continue", {}).get("cmcontinue")
                if not cmcontinue:
                    break

            for subcat in subcats:
                if count >= max_per_category:
                    return
                async for title in _crawl(subcat, depth + 1):
                    yield title

        async for title in _crawl(category, 0):
            yield title

    async def process_batch(self, batch: list[str], category: str) -> int:
        data = await self.api_query(
            titles="|".join(batch),
            prop="extracts|pageimages",
            exintro=1,
            explaintext=1,
            exsectionformat="plain",
            piprop="thumbnail",
            pithumbsize=512,
            pilimit="max",
        )
        pages = data.get("query", {}).get("pages", {})

        tasks: list[asyncio.Task[bool]] = []
        records: list[dict[str, str]] = []

        for page_id, page in pages.items():
            try:
                if int(page_id) < 0:
                    continue
            except Exception:
                continue

            title = page.get("title", "")
            if not isinstance(title, str) or title in self.collected:
                continue

            extract = page.get("extract", "")
            thumb = page.get("thumbnail", {})
            image_url = thumb.get("source", "") if isinstance(thumb, dict) else ""

            if not isinstance(extract, str) or len(extract.strip()) < 50:
                continue
            if not isinstance(image_url, str) or not image_url:
                continue

            ext = Path(unquote(image_url)).suffix.lower().split("?")[0]
            if ext in SKIP_IMAGE_EXTENSIONS:
                continue

            safe_name = hashlib.md5(title.encode("utf-8")).hexdigest()[:12]
            final_ext = ext if ext and len(ext) <= 5 else ".jpg"
            img_path = IMAGES_DIR / f"{safe_name}{final_ext}"

            records.append(
                {
                    "title": title,
                    "text": extract.strip(),
                    "image_path": str(img_path),
                    "image_url": image_url,
                    "category": category,
                }
            )
            tasks.append(asyncio.create_task(self.download_image(image_url, img_path)))

        if not tasks:
            return 0

        results = await asyncio.gather(*tasks)

        if self.meta_f is None or self.pbar is None:
            raise RuntimeError("Output files are not initialized")

        added = 0
        for record, success in zip(records, results):
            if not success:
                continue
            self.meta_f.write(json.dumps(record, ensure_ascii=False) + "\n")
            self.collected.add(record["title"])
            added += 1
            self.pbar.update(1)

        self.meta_f.flush()
        return added

    async def collect_from_category(self, category: str, limit: int) -> int:
        cat_count = 0
        batch: list[str] = []

        async for title in self.iter_category_pages(category, limit * 3):
            if cat_count >= limit or len(self.collected) >= self.max_total:
                break
            if title in self.collected:
                continue

            batch.append(title)
            if len(batch) >= 50:
                cat_count += await self.process_batch(batch, category)
                batch = []

        if batch and cat_count < limit and len(self.collected) < self.max_total:
            cat_count += await self.process_batch(batch, category)

        return cat_count

    async def run(self) -> None:
        await self.init()
        try:
            base_per_cat = self.max_total // len(CATEGORIES)
            tqdm.write(f"Pass 1: up to {base_per_cat} per category ({len(CATEGORIES)} categories)")

            cat_stats: dict[str, int] = {}
            for category in CATEGORIES:
                if len(self.collected) >= self.max_total:
                    break
                tqdm.write(f"\n📂 {category}")
                n = await self.collect_from_category(category, base_per_cat)
                cat_stats[category] = n
                tqdm.write(f"   ✓ {n} pairs")
                self.save_checkpoint()

            remaining = self.max_total - len(self.collected)
            if remaining > 0:
                big_cats = sorted(cat_stats, key=lambda c: cat_stats[c], reverse=True)
                extra_per_cat = remaining // min(len(big_cats), 10) + 50

                tqdm.write(f"\nPass 2: collecting {remaining} more from largest categories")
                for category in big_cats:
                    if len(self.collected) >= self.max_total:
                        break
                    tqdm.write(f"\n📂 {category} (extra)")
                    n = await self.collect_from_category(category, extra_per_cat)
                    tqdm.write(f"   ✓ {n} extra pairs")
                    self.save_checkpoint()
        finally:
            await self.close()

        print(f"\nDone! Collected {len(self.collected)} pairs.")
        print(f"Images: {IMAGES_DIR}")
        print(f"Metadata: {METADATA_FILE}")


def main() -> None:
    parser = argparse.ArgumentParser(description="Collect Wikipedia image-text pairs (async)")
    parser.add_argument("--max-total", type=int, default=10000, help="Total pairs to collect")
    parser.add_argument("--max-depth", type=int, default=2, help="Max category recursion depth")
    parser.add_argument("--resume", action="store_true", help="Resume from checkpoint")
    args = parser.parse_args()

    collector = AsyncCollector(args.max_total, args.max_depth, args.resume)
    asyncio.run(collector.run())


if __name__ == "__main__":
    main()