Spaces:
Sleeping
Sleeping
File size: 7,449 Bytes
ece157f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 | """
Сбор данных с ru.wikipedia.org для fine-tuning CLIP.
Собирает случайные статьи с изображениями — пары (картинка, текст).
Использование:
python collect_data.py
python collect_data.py --max-total 10000
python collect_data.py --max-total 10000 --resume
"""
import argparse
import hashlib
import json
import time
from pathlib import Path
from urllib.parse import unquote
import requests
from tqdm import tqdm
API_URL = "https://ru.wikipedia.org/w/api.php"
SESSION = requests.Session()
SESSION.headers.update({
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
"Referer": "https://ru.wikipedia.org/",
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8",
})
SKIP_IMAGE_EXTENSIONS = {".svg", ".gif", ".ogg", ".ogv", ".webm", ".pdf", ".djvu"}
DATA_DIR = Path("data")
IMAGES_DIR = DATA_DIR / "images"
METADATA_FILE = DATA_DIR / "metadata.jsonl"
CHECKPOINT_FILE = DATA_DIR / "checkpoint.json"
def api_query(**params):
"""Запрос к MediaWiki API с rate limiting."""
params.setdefault("format", "json")
params.setdefault("action", "query")
time.sleep(0.1)
resp = SESSION.get(API_URL, params=params, timeout=30)
resp.raise_for_status()
return resp.json()
def get_random_titles(count: int = 20) -> list[str]:
"""Получить случайные заголовки статей (namespace 0 = основные статьи)."""
data = api_query(list="random", rnnamespace=0, rnlimit=count)
return [p["title"] for p in data.get("query", {}).get("random", [])]
def get_article_data(titles: list[str]) -> dict:
"""Получить extract + thumbnail для пачки статей (до 20)."""
data = api_query(
titles="|".join(titles),
prop="extracts|pageimages",
exintro=True,
explaintext=True,
exsectionformat="plain",
piprop="thumbnail",
pithumbsize=512,
pilimit="max",
)
pages = data.get("query", {}).get("pages", {})
results = {}
for page_id, page in pages.items():
if int(page_id) < 0:
continue
title = page.get("title", "")
extract = page.get("extract", "").strip()
thumb = page.get("thumbnail", {})
image_url = thumb.get("source", "")
results[title] = {"extract": extract, "image_url": image_url}
return results
def download_image(url: str, save_path: Path, max_retries: int = 3) -> bool:
"""Скачать изображение с retry и exponential backoff."""
for attempt in range(max_retries):
try:
time.sleep(0.2 + attempt * 2)
resp = SESSION.get(url, timeout=30, stream=True)
if resp.status_code == 429:
wait = int(resp.headers.get("Retry-After", 5 * (attempt + 1)))
tqdm.write(f" ⏳ Rate limited, waiting {wait}s...")
time.sleep(wait)
continue
resp.raise_for_status()
with open(save_path, "wb") as f:
for chunk in resp.iter_content(8192):
f.write(chunk)
return True
except requests.exceptions.HTTPError as e:
if "429" in str(e) and attempt < max_retries - 1:
time.sleep(5 * (attempt + 1))
continue
tqdm.write(f" ⚠ Download failed: {e}")
return False
except Exception as e:
tqdm.write(f" ⚠ Download failed: {e}")
return False
return False
def image_filename(title: str, url: str) -> str:
ext = Path(unquote(url)).suffix.lower().split("?")[0]
if not ext or len(ext) > 5:
ext = ".jpg"
safe_name = hashlib.md5(title.encode()).hexdigest()[:12]
return f"{safe_name}{ext}"
def load_checkpoint() -> set[str]:
if CHECKPOINT_FILE.exists():
with open(CHECKPOINT_FILE) as f:
return set(json.load(f).get("collected_titles", []))
return set()
def save_checkpoint(collected: set[str]):
with open(CHECKPOINT_FILE, "w") as f:
json.dump({"collected_titles": list(collected)}, f, ensure_ascii=False)
def main():
parser = argparse.ArgumentParser(description="Collect random Wikipedia image-text pairs")
parser.add_argument("--max-total", type=int, default=10000, help="Total pairs to collect")
parser.add_argument("--resume", action="store_true", help="Resume from checkpoint")
args = parser.parse_args()
IMAGES_DIR.mkdir(parents=True, exist_ok=True)
collected = load_checkpoint() if args.resume else set()
mode = "a" if args.resume and METADATA_FILE.exists() else "w"
total = len(collected)
skipped = 0
pbar = tqdm(total=args.max_total, initial=total, desc="Collecting")
with open(METADATA_FILE, mode, encoding="utf-8") as meta_f:
while total < args.max_total:
# Берём пачку случайных статей
random_titles = get_random_titles(20)
# Фильтруем уже собранные
new_titles = [t for t in random_titles if t not in collected]
if not new_titles:
continue
# Получаем данные статей
article_data = get_article_data(new_titles)
for title, info in article_data.items():
if total >= args.max_total:
break
if title in collected:
continue
extract = info["extract"]
image_url = info["image_url"]
# Пропуск статей без текста или картинки
if not extract or len(extract) < 50:
skipped += 1
continue
if not image_url:
skipped += 1
continue
# Пропуск не-фото форматов
ext = Path(unquote(image_url)).suffix.lower().split("?")[0]
if ext in SKIP_IMAGE_EXTENSIONS:
skipped += 1
continue
# Скачиваем
fname = image_filename(title, image_url)
img_path = IMAGES_DIR / fname
if not img_path.exists():
if not download_image(image_url, img_path):
skipped += 1
continue
record = {
"title": title,
"text": extract,
"image_path": str(img_path),
"image_url": image_url,
}
meta_f.write(json.dumps(record, ensure_ascii=False) + "\n")
meta_f.flush()
collected.add(title)
total += 1
pbar.update(1)
# Checkpoint каждые 100 статей
if total % 100 < 20:
save_checkpoint(collected)
pbar.set_postfix(skipped=skipped)
save_checkpoint(collected)
pbar.close()
print(f"\nDone! Collected {total} pairs (skipped {skipped} without image/text).")
print(f"Images: {IMAGES_DIR}")
print(f"Metadata: {METADATA_FILE}")
if __name__ == "__main__":
main()
|