# semantic-image-search-api/scripts/create_sample_data.py
# Uploaded by tharejasarthak; commit 9508d8c ("Deploy fresh to HF").
"""Create minimal sample data for quick demo without full COCO/Flickr30k download."""
import argparse
import sys
from pathlib import Path
# Project root: the parent of the directory that contains this script.
ROOT = Path(__file__).resolve().parents[1]
def download_hf_sample(output_dir: Path, max_images: int = 50):
    """Download a small captioned image sample into ``output_dir``.

    A few hand-picked Unsplash images are queued first, followed by
    URL/caption pairs streamed from the
    ``ChristophSchuhmann/MS_COCO_2017_URL_TEXT`` dataset, until
    ``max_images`` candidates have been collected. Every candidate is fetched
    in a worker thread, converted to RGB, shrunk to fit within 300x300 pixels
    and saved as ``sample_<index>.jpg``. A ``captions.csv`` file (file name,
    caption) is written to the parent directory of ``output_dir``.

    Args:
        output_dir: Directory that receives the JPEG files; created if missing.
        max_images: Upper bound on the number of URLs attempted, curated
            images included. Failed downloads are skipped, so fewer images
            may end up on disk. Values <= 0 download nothing.

    Returns:
        A list of ``(image_id, path, caption)`` tuples for the images that
        were saved, ordered by candidate index so repeated runs are stable.

    Exits the process with status 1 when ``datasets``, ``requests`` or
    ``Pillow`` cannot be imported.
    """
    try:
        from datasets import load_dataset
        import requests
        from PIL import Image
        from io import BytesIO
        import concurrent.futures
        import csv
    except ImportError:
        print("Install: pip install datasets requests Pillow")
        sys.exit(1)

    output_dir.mkdir(parents=True, exist_ok=True)

    # Curated images
    curated = [
        (
            "https://images.unsplash.com/photo-1548199973-03cce0bbc87b?w=600&q=80",
            "Two dogs playing in the snow, running and jumping.",
        ),
        (
            "https://images.unsplash.com/photo-1554580665-9831c2063eef?w=600&q=80",
            "A man selling ice cream or desserts from a small cart outdoors.",
        ),
        (
            "https://images.unsplash.com/photo-1514888286974-6c03e2ca1dba?w=600&q=80",
            "A cat sitting on a windowsill looking out.",
        ),
    ]

    # The curated entries count against the limit too, so a small max_images
    # is honoured instead of always queuing all of them.
    tasks = curated[:max(max_images, 0)]
    seen_urls = {url for url, _ in tasks}

    # Only open the dataset stream when more candidates are actually needed.
    if len(tasks) < max_images:
        print("Loading COCO dataset stream...")
        ds = load_dataset(
            "ChristophSchuhmann/MS_COCO_2017_URL_TEXT",
            split="train",
            streaming=True,
        )
        for ex in ds:
            if len(tasks) >= max_images:
                break
            url = ex.get("URL")
            cap = ex.get("TEXT")
            if url and cap and url not in seen_urls:
                tasks.append((url, cap))
                seen_urls.add(url)

    print(f"Collected {len(tasks)} target URLs. Starting parallel download...")

    def download_image(index, url, caption):
        """Fetch and store one image; return its record or None on failure."""
        try:
            r = requests.get(url, timeout=5)
            if r.status_code != 200:
                print(f"Skipping {url}: HTTP {r.status_code}", file=sys.stderr)
                return None
            img = Image.open(BytesIO(r.content)).convert("RGB")
            # Resize image slightly to save disk space for large datasets
            img.thumbnail((300, 300))
            out_path = output_dir / f"sample_{index}.jpg"
            img.save(out_path, format="JPEG", quality=85)
        except Exception as exc:
            # Best effort: one bad URL or corrupt image must not abort the
            # whole batch, but the failure is reported instead of hidden.
            print(f"Skipping {url}: {exc}", file=sys.stderr)
            return None
        return f"img_{index}", out_path, caption

    progress_every = 10
    saved = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:
        future_to_index = {
            executor.submit(download_image, i, url, cap): i
            for i, (url, cap) in enumerate(tasks)
        }
        for future in concurrent.futures.as_completed(future_to_index):
            result = future.result()
            if result is None:
                continue
            saved[future_to_index[future]] = result
            if len(saved) % progress_every == 0:
                print(f"Downloaded {len(saved)}/{len(tasks)} images")

    # Completion order is arbitrary; restore candidate order for stable output.
    items = [saved[i] for i in sorted(saved)]

    # Write captions to CSV in the parent directory
    csv_path = output_dir.parent / "captions.csv"
    with open(csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerows((path.name, cap) for _, path, cap in items)

    print(f"Saved {len(items)} captions to {csv_path}")
    return items
def create_placeholder_images(output_dir: Path, count: int = 20):
    """Generate solid-colour 224x224 JPEGs to stand in for real images.

    Colours cycle through red, green, blue, yellow and purple; the colour
    name doubles as the caption. Files are named ``demo_<n>.jpg`` inside
    ``output_dir``, which is created when missing.

    Args:
        output_dir: Directory that receives the generated files.
        count: Number of images to create.

    Returns:
        A list of ``(image_id, path, label)`` tuples, one per image.

    Exits the process with status 1 when Pillow cannot be imported.
    """
    try:
        from PIL import Image
    except ImportError:
        print("PIL required")
        sys.exit(1)

    # Each entry pairs an RGB fill colour with the caption used for it.
    palette = (
        ((255, 0, 0), "red"),
        ((0, 255, 0), "green"),
        ((0, 0, 255), "blue"),
        ((255, 255, 0), "yellow"),
        ((128, 0, 128), "purple"),
    )

    output_dir.mkdir(parents=True, exist_ok=True)

    def render(n):
        """Write the n-th placeholder to disk and return its record."""
        rgb, label = palette[n % len(palette)]
        target = output_dir / f"demo_{n}.jpg"
        Image.new("RGB", (224, 224), color=rgb).save(target)
        return f"img_{n}", target, label

    return [render(n) for n in range(count)]
def main():
    """Parse the command line, generate the sample images and report the result.

    Options:
        --output: Image directory; a relative path is resolved against ROOT.
        --method: ``placeholder`` (default) or ``hf``.
        --count: Number of images to create or download (default 50).

    Returns:
        The list of items produced by the selected generator.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--output", type=Path, default=ROOT / "data" / "images")
    cli.add_argument("--method", choices=["hf", "placeholder"], default="placeholder")
    cli.add_argument("--count", type=int, default=50)
    opts = cli.parse_args()

    # Relative paths are anchored at the project root, not the current directory.
    target = opts.output if opts.output.is_absolute() else ROOT / opts.output
    target.mkdir(parents=True, exist_ok=True)

    if opts.method != "placeholder":
        items = download_hf_sample(target, max_images=opts.count)
    else:
        items = create_placeholder_images(target, count=opts.count)

    print(f"Created {len(items)} images in {target}")
    return items
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()