letitbE commited on
Commit
ece157f
·
0 Parent(s):

Add data collection scripts and requirements

Browse files
Files changed (5) hide show
  1. .gitattributes +4 -0
  2. .gitignore +7 -0
  3. collect_data.py +211 -0
  4. collect_data_async.py +391 -0
  5. requirements.txt +8 -0
.gitattributes ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
2
+ *.png filter=lfs diff=lfs merge=lfs -text
3
+ *.jpg filter=lfs diff=lfs merge=lfs -text
4
+ *.JPG filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ *.bin
2
+ data.zip
3
+ data/
4
+ hf_cache/
5
+ __pycache__/
6
+ *.safetensors
7
+ app/static/uploads/
collect_data.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Сбор данных с ru.wikipedia.org для fine-tuning CLIP.
3
+ Собирает случайные статьи с изображениями — пары (картинка, текст).
4
+
5
+ Использование:
6
+ python collect_data.py
7
+ python collect_data.py --max-total 10000
8
+ python collect_data.py --max-total 10000 --resume
9
+ """
10
+
11
+ import argparse
12
+ import hashlib
13
+ import json
14
+ import time
15
+ from pathlib import Path
16
+ from urllib.parse import unquote
17
+
18
+ import requests
19
+ from tqdm import tqdm
20
+
21
+ API_URL = "https://ru.wikipedia.org/w/api.php"
22
+ SESSION = requests.Session()
23
+ SESSION.headers.update({
24
+ "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
25
+ "Referer": "https://ru.wikipedia.org/",
26
+ "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8",
27
+ })
28
+
29
+ SKIP_IMAGE_EXTENSIONS = {".svg", ".gif", ".ogg", ".ogv", ".webm", ".pdf", ".djvu"}
30
+
31
+ DATA_DIR = Path("data")
32
+ IMAGES_DIR = DATA_DIR / "images"
33
+ METADATA_FILE = DATA_DIR / "metadata.jsonl"
34
+ CHECKPOINT_FILE = DATA_DIR / "checkpoint.json"
35
+
36
+
37
+ def api_query(**params):
38
+ """Запрос к MediaWiki API с rate limiting."""
39
+ params.setdefault("format", "json")
40
+ params.setdefault("action", "query")
41
+ time.sleep(0.1)
42
+ resp = SESSION.get(API_URL, params=params, timeout=30)
43
+ resp.raise_for_status()
44
+ return resp.json()
45
+
46
+
47
+ def get_random_titles(count: int = 20) -> list[str]:
48
+ """Получить случайные заголовки статей (namespace 0 = основные статьи)."""
49
+ data = api_query(list="random", rnnamespace=0, rnlimit=count)
50
+ return [p["title"] for p in data.get("query", {}).get("random", [])]
51
+
52
+
53
+ def get_article_data(titles: list[str]) -> dict:
54
+ """Получить extract + thumbnail для пачки статей (до 20)."""
55
+ data = api_query(
56
+ titles="|".join(titles),
57
+ prop="extracts|pageimages",
58
+ exintro=True,
59
+ explaintext=True,
60
+ exsectionformat="plain",
61
+ piprop="thumbnail",
62
+ pithumbsize=512,
63
+ pilimit="max",
64
+ )
65
+ pages = data.get("query", {}).get("pages", {})
66
+ results = {}
67
+ for page_id, page in pages.items():
68
+ if int(page_id) < 0:
69
+ continue
70
+ title = page.get("title", "")
71
+ extract = page.get("extract", "").strip()
72
+ thumb = page.get("thumbnail", {})
73
+ image_url = thumb.get("source", "")
74
+ results[title] = {"extract": extract, "image_url": image_url}
75
+ return results
76
+
77
+
78
+ def download_image(url: str, save_path: Path, max_retries: int = 3) -> bool:
79
+ """Скачать изображение с retry и exponential backoff."""
80
+ for attempt in range(max_retries):
81
+ try:
82
+ time.sleep(0.2 + attempt * 2)
83
+ resp = SESSION.get(url, timeout=30, stream=True)
84
+ if resp.status_code == 429:
85
+ wait = int(resp.headers.get("Retry-After", 5 * (attempt + 1)))
86
+ tqdm.write(f" ⏳ Rate limited, waiting {wait}s...")
87
+ time.sleep(wait)
88
+ continue
89
+ resp.raise_for_status()
90
+ with open(save_path, "wb") as f:
91
+ for chunk in resp.iter_content(8192):
92
+ f.write(chunk)
93
+ return True
94
+ except requests.exceptions.HTTPError as e:
95
+ if "429" in str(e) and attempt < max_retries - 1:
96
+ time.sleep(5 * (attempt + 1))
97
+ continue
98
+ tqdm.write(f" ⚠ Download failed: {e}")
99
+ return False
100
+ except Exception as e:
101
+ tqdm.write(f" ⚠ Download failed: {e}")
102
+ return False
103
+ return False
104
+
105
+
106
+ def image_filename(title: str, url: str) -> str:
107
+ ext = Path(unquote(url)).suffix.lower().split("?")[0]
108
+ if not ext or len(ext) > 5:
109
+ ext = ".jpg"
110
+ safe_name = hashlib.md5(title.encode()).hexdigest()[:12]
111
+ return f"{safe_name}{ext}"
112
+
113
+
114
+ def load_checkpoint() -> set[str]:
115
+ if CHECKPOINT_FILE.exists():
116
+ with open(CHECKPOINT_FILE) as f:
117
+ return set(json.load(f).get("collected_titles", []))
118
+ return set()
119
+
120
+
121
+ def save_checkpoint(collected: set[str]):
122
+ with open(CHECKPOINT_FILE, "w") as f:
123
+ json.dump({"collected_titles": list(collected)}, f, ensure_ascii=False)
124
+
125
+
126
+ def main():
127
+ parser = argparse.ArgumentParser(description="Collect random Wikipedia image-text pairs")
128
+ parser.add_argument("--max-total", type=int, default=10000, help="Total pairs to collect")
129
+ parser.add_argument("--resume", action="store_true", help="Resume from checkpoint")
130
+ args = parser.parse_args()
131
+
132
+ IMAGES_DIR.mkdir(parents=True, exist_ok=True)
133
+
134
+ collected = load_checkpoint() if args.resume else set()
135
+ mode = "a" if args.resume and METADATA_FILE.exists() else "w"
136
+
137
+ total = len(collected)
138
+ skipped = 0
139
+ pbar = tqdm(total=args.max_total, initial=total, desc="Collecting")
140
+
141
+ with open(METADATA_FILE, mode, encoding="utf-8") as meta_f:
142
+ while total < args.max_total:
143
+ # Берём пачку случайных статей
144
+ random_titles = get_random_titles(20)
145
+ # Фильтруем уже собранные
146
+ new_titles = [t for t in random_titles if t not in collected]
147
+ if not new_titles:
148
+ continue
149
+
150
+ # Получаем данные статей
151
+ article_data = get_article_data(new_titles)
152
+
153
+ for title, info in article_data.items():
154
+ if total >= args.max_total:
155
+ break
156
+ if title in collected:
157
+ continue
158
+
159
+ extract = info["extract"]
160
+ image_url = info["image_url"]
161
+
162
+ # Пропуск статей без текста или картинки
163
+ if not extract or len(extract) < 50:
164
+ skipped += 1
165
+ continue
166
+ if not image_url:
167
+ skipped += 1
168
+ continue
169
+
170
+ # Пропуск не-фото форматов
171
+ ext = Path(unquote(image_url)).suffix.lower().split("?")[0]
172
+ if ext in SKIP_IMAGE_EXTENSIONS:
173
+ skipped += 1
174
+ continue
175
+
176
+ # Скачиваем
177
+ fname = image_filename(title, image_url)
178
+ img_path = IMAGES_DIR / fname
179
+
180
+ if not img_path.exists():
181
+ if not download_image(image_url, img_path):
182
+ skipped += 1
183
+ continue
184
+
185
+ record = {
186
+ "title": title,
187
+ "text": extract,
188
+ "image_path": str(img_path),
189
+ "image_url": image_url,
190
+ }
191
+ meta_f.write(json.dumps(record, ensure_ascii=False) + "\n")
192
+ meta_f.flush()
193
+
194
+ collected.add(title)
195
+ total += 1
196
+ pbar.update(1)
197
+
198
+ # Checkpoint каждые 100 статей
199
+ if total % 100 < 20:
200
+ save_checkpoint(collected)
201
+ pbar.set_postfix(skipped=skipped)
202
+
203
+ save_checkpoint(collected)
204
+ pbar.close()
205
+ print(f"\nDone! Collected {total} pairs (skipped {skipped} without image/text).")
206
+ print(f"Images: {IMAGES_DIR}")
207
+ print(f"Metadata: {METADATA_FILE}")
208
+
209
+
210
+ if __name__ == "__main__":
211
+ main()
collect_data_async.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Асинхронный сбор данных с ru.wikipedia.org для fine-tuning.
3
+ Ускоряет исходный [collect_data.py](collect_data.py) за счет конкурентной загрузки картинок,
4
+ но оставляет API-запросы к Wikipedia достаточно бережными.
5
+
6
+ Установка:
7
+ pip install aiohttp tqdm
8
+
9
+ Примеры:
10
+ python collect_data_async.py
11
+ python collect_data_async.py --max-total 10000 --max-depth 2 --resume
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import argparse
17
+ import asyncio
18
+ import hashlib
19
+ import json
20
+ from pathlib import Path
21
+ from typing import Any, AsyncIterator, TextIO
22
+ from urllib.parse import unquote
23
+
24
+ import aiohttp
25
+ from tqdm import tqdm
26
+
27
+ API_URL = "https://ru.wikipedia.org/w/api.php"
28
+ HEADERS = {
29
+ # Укажи свои контакты при желании; для Wikimedia лучше честный bot UA, а не браузерный.
30
+ "User-Agent": "ML2HomeworkCollector/1.0 (educational project; contact: local-run)",
31
+ "Accept-Encoding": "gzip, deflate",
32
+ }
33
+ REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=30)
34
+
35
+ CATEGORIES = [
36
+ "Категория:Кошки (род)",
37
+ "Категория:Породы собак",
38
+ "Категория:Совообразные",
39
+ "Категория:Попугаеобразные",
40
+ "Категория:Ястребообразные",
41
+ "Категория:Бабочки",
42
+ "Категория:Жуки",
43
+ "Категория:Пресноводные рыбы",
44
+ "Категория:Акулы",
45
+ "Категория:Съедобные грибы",
46
+ "Категория:Ядовитые грибы",
47
+ "Категория:Фрукты",
48
+ "Категория:Овощи",
49
+ "Категория:Ягоды",
50
+ "Категория:Орехи",
51
+ "Категория:Хвойные",
52
+ "Категория:Цветковые растения",
53
+ "Категория:Кактусовые",
54
+ "Категория:Легковые автомобили",
55
+ "Категория:Мотоциклы",
56
+ "Категория:Вертолёты",
57
+ "Категория:Самолёты",
58
+ "Категория:Танки",
59
+ "Категория:Струнные музыкальные инструменты",
60
+ "Категория:Духовые музыкальные инструменты",
61
+ "Категория:Ударные музыкальные инструменты",
62
+ "Категория:Супы",
63
+ "Категория:Салаты",
64
+ "Категория:Пирожные",
65
+ "Категория:Мосты России",
66
+ "Категория:Мосты Европы",
67
+ "Категория:Замки Европы",
68
+ "Категория:Храмы России",
69
+ "Категория:Небоскрёбы",
70
+ "Категория:Маяки",
71
+ "Категория:Вулканы",
72
+ "Категория:Озёра России",
73
+ "Категория:Водопады",
74
+ "Категория:Холодное оружие",
75
+ "Категория:Огнестрельное оружие",
76
+ "Категория:Минералы",
77
+ "Категория:Драгоценные камни",
78
+ "Категория:Монеты",
79
+ "Категория:Флаги государств",
80
+ ]
81
+
82
+ SKIP_IMAGE_EXTENSIONS = {".svg", ".gif", ".ogg", ".ogv", ".webm", ".pdf", ".djvu"}
83
+
84
+ DATA_DIR = Path("data")
85
+ IMAGES_DIR = DATA_DIR / "images"
86
+ METADATA_FILE = DATA_DIR / "metadata.jsonl"
87
+ CHECKPOINT_FILE = DATA_DIR / "checkpoint.json"
88
+
89
+
90
+ class AsyncCollector:
91
+ def __init__(self, max_total: int, max_depth: int, resume: bool):
92
+ self.max_total = max_total
93
+ self.max_depth = max_depth
94
+ self.resume = resume
95
+ self.collected: set[str] = set()
96
+ self.session: aiohttp.ClientSession | None = None
97
+ self.meta_f: TextIO | None = None
98
+ self.pbar: tqdm | None = None
99
+
100
+ # API лучше не долбить параллельно; ускорение в основном будет на картинках.
101
+ self.api_sem = asyncio.Semaphore(1)
102
+ self.img_sem = asyncio.Semaphore(8)
103
+
104
+ async def init(self) -> None:
105
+ IMAGES_DIR.mkdir(parents=True, exist_ok=True)
106
+ connector = aiohttp.TCPConnector(limit=16)
107
+ self.session = aiohttp.ClientSession(
108
+ headers=HEADERS,
109
+ connector=connector,
110
+ timeout=REQUEST_TIMEOUT,
111
+ )
112
+
113
+ if self.resume and CHECKPOINT_FILE.exists():
114
+ with open(CHECKPOINT_FILE, encoding="utf-8") as f:
115
+ self.collected = set(json.load(f).get("collected_titles", []))
116
+
117
+ mode = "a" if self.resume and METADATA_FILE.exists() else "w"
118
+ self.meta_f = open(METADATA_FILE, mode, encoding="utf-8")
119
+ self.pbar = tqdm(total=self.max_total, initial=len(self.collected), desc="Collecting")
120
+
121
+ async def close(self) -> None:
122
+ if self.session is not None:
123
+ await self.session.close()
124
+ if self.meta_f is not None:
125
+ self.meta_f.close()
126
+ if self.pbar is not None:
127
+ self.pbar.close()
128
+
129
+ def save_checkpoint(self) -> None:
130
+ with open(CHECKPOINT_FILE, "w", encoding="utf-8") as f:
131
+ json.dump({"collected_titles": list(self.collected)}, f, ensure_ascii=False)
132
+
133
+ async def api_query(self, **params: Any) -> dict[str, Any]:
134
+ if self.session is None:
135
+ raise RuntimeError("Session is not initialized")
136
+
137
+ normalized_params: dict[str, str | int | float] = {
138
+ "format": "json",
139
+ "action": "query",
140
+ }
141
+ for key, value in params.items():
142
+ if isinstance(value, bool):
143
+ normalized_params[key] = "1" if value else "0"
144
+ elif isinstance(value, (str, int, float)):
145
+ normalized_params[key] = value
146
+ else:
147
+ normalized_params[key] = str(value)
148
+
149
+ async with self.api_sem:
150
+ await asyncio.sleep(0.05)
151
+ for attempt in range(4):
152
+ try:
153
+ async with self.session.get(API_URL, params=normalized_params) as resp:
154
+ if resp.status in (403, 429):
155
+ wait = int(resp.headers.get("Retry-After", 5 * (attempt + 1)))
156
+ tqdm.write(f"API limited ({resp.status}), sleeping {wait}s")
157
+ await asyncio.sleep(wait)
158
+ continue
159
+ resp.raise_for_status()
160
+ return await resp.json()
161
+ except Exception as e:
162
+ if attempt == 3:
163
+ tqdm.write(f"API Error: {e}")
164
+ return {}
165
+ await asyncio.sleep(1.5 * (attempt + 1))
166
+ return {}
167
+
168
+ async def download_image(self, url: str, save_path: Path) -> bool:
169
+ if self.session is None:
170
+ raise RuntimeError("Session is not initialized")
171
+ if save_path.exists():
172
+ return True
173
+
174
+ async with self.img_sem:
175
+ for attempt in range(3):
176
+ try:
177
+ async with self.session.get(url) as resp:
178
+ if resp.status in (403, 429):
179
+ wait = int(resp.headers.get("Retry-After", 3 * (attempt + 1)))
180
+ await asyncio.sleep(wait)
181
+ continue
182
+ resp.raise_for_status()
183
+ content = await resp.read()
184
+ with open(save_path, "wb") as f:
185
+ f.write(content)
186
+ return True
187
+ except Exception:
188
+ if attempt == 2:
189
+ return False
190
+ await asyncio.sleep(1 + attempt)
191
+ return False
192
+
193
+ async def iter_category_pages(self, category: str, max_per_category: int) -> AsyncIterator[str]:
194
+ visited_cats: set[str] = set()
195
+ count = 0
196
+
197
+ async def _crawl(cat: str, depth: int) -> AsyncIterator[str]:
198
+ nonlocal count
199
+ if depth > self.max_depth or cat in visited_cats or count >= max_per_category:
200
+ return
201
+ visited_cats.add(cat)
202
+
203
+ cmcontinue: str | None = None
204
+ subcats: list[str] = []
205
+
206
+ while count < max_per_category:
207
+ params: dict[str, Any] = {
208
+ "list": "categorymembers",
209
+ "cmtitle": cat,
210
+ "cmlimit": 50,
211
+ "cmtype": "page|subcat",
212
+ }
213
+ if cmcontinue:
214
+ params["cmcontinue"] = cmcontinue
215
+
216
+ data = await self.api_query(**params)
217
+ members = data.get("query", {}).get("categorymembers", [])
218
+ if not members and "error" in data:
219
+ return
220
+
221
+ for member in members:
222
+ if count >= max_per_category:
223
+ return
224
+ if member.get("ns") == 0:
225
+ title = member.get("title")
226
+ if isinstance(title, str):
227
+ count += 1
228
+ yield title
229
+ elif member.get("ns") == 14:
230
+ title = member.get("title")
231
+ if isinstance(title, str):
232
+ subcats.append(title)
233
+
234
+ cmcontinue = data.get("continue", {}).get("cmcontinue")
235
+ if not cmcontinue:
236
+ break
237
+
238
+ for subcat in subcats:
239
+ if count >= max_per_category:
240
+ return
241
+ async for title in _crawl(subcat, depth + 1):
242
+ yield title
243
+
244
+ async for title in _crawl(category, 0):
245
+ yield title
246
+
247
+ async def process_batch(self, batch: list[str], category: str) -> int:
248
+ data = await self.api_query(
249
+ titles="|".join(batch),
250
+ prop="extracts|pageimages",
251
+ exintro=1,
252
+ explaintext=1,
253
+ exsectionformat="plain",
254
+ piprop="thumbnail",
255
+ pithumbsize=512,
256
+ pilimit="max",
257
+ )
258
+ pages = data.get("query", {}).get("pages", {})
259
+
260
+ tasks: list[asyncio.Task[bool]] = []
261
+ records: list[dict[str, str]] = []
262
+
263
+ for page_id, page in pages.items():
264
+ try:
265
+ if int(page_id) < 0:
266
+ continue
267
+ except Exception:
268
+ continue
269
+
270
+ title = page.get("title", "")
271
+ if not isinstance(title, str) or title in self.collected:
272
+ continue
273
+
274
+ extract = page.get("extract", "")
275
+ thumb = page.get("thumbnail", {})
276
+ image_url = thumb.get("source", "") if isinstance(thumb, dict) else ""
277
+
278
+ if not isinstance(extract, str) or len(extract.strip()) < 50:
279
+ continue
280
+ if not isinstance(image_url, str) or not image_url:
281
+ continue
282
+
283
+ ext = Path(unquote(image_url)).suffix.lower().split("?")[0]
284
+ if ext in SKIP_IMAGE_EXTENSIONS:
285
+ continue
286
+
287
+ safe_name = hashlib.md5(title.encode("utf-8")).hexdigest()[:12]
288
+ final_ext = ext if ext and len(ext) <= 5 else ".jpg"
289
+ img_path = IMAGES_DIR / f"{safe_name}{final_ext}"
290
+
291
+ records.append(
292
+ {
293
+ "title": title,
294
+ "text": extract.strip(),
295
+ "image_path": str(img_path),
296
+ "image_url": image_url,
297
+ "category": category,
298
+ }
299
+ )
300
+ tasks.append(asyncio.create_task(self.download_image(image_url, img_path)))
301
+
302
+ if not tasks:
303
+ return 0
304
+
305
+ results = await asyncio.gather(*tasks)
306
+
307
+ if self.meta_f is None or self.pbar is None:
308
+ raise RuntimeError("Output files are not initialized")
309
+
310
+ added = 0
311
+ for record, success in zip(records, results):
312
+ if not success:
313
+ continue
314
+ self.meta_f.write(json.dumps(record, ensure_ascii=False) + "\n")
315
+ self.collected.add(record["title"])
316
+ added += 1
317
+ self.pbar.update(1)
318
+
319
+ self.meta_f.flush()
320
+ return added
321
+
322
+ async def collect_from_category(self, category: str, limit: int) -> int:
323
+ cat_count = 0
324
+ batch: list[str] = []
325
+
326
+ async for title in self.iter_category_pages(category, limit * 3):
327
+ if cat_count >= limit or len(self.collected) >= self.max_total:
328
+ break
329
+ if title in self.collected:
330
+ continue
331
+
332
+ batch.append(title)
333
+ if len(batch) >= 50:
334
+ cat_count += await self.process_batch(batch, category)
335
+ batch = []
336
+
337
+ if batch and cat_count < limit and len(self.collected) < self.max_total:
338
+ cat_count += await self.process_batch(batch, category)
339
+
340
+ return cat_count
341
+
342
+ async def run(self) -> None:
343
+ await self.init()
344
+ try:
345
+ base_per_cat = self.max_total // len(CATEGORIES)
346
+ tqdm.write(f"Pass 1: up to {base_per_cat} per category ({len(CATEGORIES)} categories)")
347
+
348
+ cat_stats: dict[str, int] = {}
349
+ for category in CATEGORIES:
350
+ if len(self.collected) >= self.max_total:
351
+ break
352
+ tqdm.write(f"\n📂 {category}")
353
+ n = await self.collect_from_category(category, base_per_cat)
354
+ cat_stats[category] = n
355
+ tqdm.write(f" ✓ {n} pairs")
356
+ self.save_checkpoint()
357
+
358
+ remaining = self.max_total - len(self.collected)
359
+ if remaining > 0:
360
+ big_cats = sorted(cat_stats, key=lambda c: cat_stats[c], reverse=True)
361
+ extra_per_cat = remaining // min(len(big_cats), 10) + 50
362
+
363
+ tqdm.write(f"\nPass 2: collecting {remaining} more from largest categories")
364
+ for category in big_cats:
365
+ if len(self.collected) >= self.max_total:
366
+ break
367
+ tqdm.write(f"\n📂 {category} (extra)")
368
+ n = await self.collect_from_category(category, extra_per_cat)
369
+ tqdm.write(f" ✓ {n} extra pairs")
370
+ self.save_checkpoint()
371
+ finally:
372
+ await self.close()
373
+
374
+ print(f"\nDone! Collected {len(self.collected)} pairs.")
375
+ print(f"Images: {IMAGES_DIR}")
376
+ print(f"Metadata: {METADATA_FILE}")
377
+
378
+
379
+ def main() -> None:
380
+ parser = argparse.ArgumentParser(description="Collect Wikipedia image-text pairs (async)")
381
+ parser.add_argument("--max-total", type=int, default=10000, help="Total pairs to collect")
382
+ parser.add_argument("--max-depth", type=int, default=2, help="Max category recursion depth")
383
+ parser.add_argument("--resume", action="store_true", help="Resume from checkpoint")
384
+ args = parser.parse_args()
385
+
386
+ collector = AsyncCollector(args.max_total, args.max_depth, args.resume)
387
+ asyncio.run(collector.run())
388
+
389
+
390
+ if __name__ == "__main__":
391
+ main()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ python-multipart
4
+ jinja2
5
+ torch
6
+ transformers
7
+ peft
8
+ pillow