"""Download, preprocess, visualize, and split the QuickDraw dataset.

Produces train/val/test .npy arrays under dataset/processed/, class-name
mapping JSON files, and refreshes the category list in README.md.
"""

import json
import os
import re
import shutil

import matplotlib.pyplot as plt
import numpy as np
import requests
from sklearn.model_selection import train_test_split

ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
PROCESSED_DIR = os.path.join(ROOT_DIR, 'dataset', 'processed')
QUICKDRAW_DIR = os.path.join(ROOT_DIR, 'dataset', 'quickdraw')
MODEL_DIR = os.path.join(ROOT_DIR, 'model')

os.makedirs(PROCESSED_DIR, exist_ok=True)

# Class names grouped for README presentation; training uses the flat
# CLASSES list below. Group membership does not affect labels.
CATEGORIES = {
    'Animals': [
        'bear', 'bee', 'butterfly', 'cat', 'cow', 'crab', 'camel', 'dog',
        'dolphin', 'duck', 'elephant', 'fish', 'flamingo', 'frog', 'giraffe',
        'hedgehog', 'horse', 'kangaroo', 'lion', 'monkey', 'octopus', 'owl',
        'panda', 'penguin', 'pig', 'rabbit', 'shark', 'sheep', 'snake',
        'spider', 'tiger', 'whale', 'zebra',
    ],
    'Food': [
        'apple', 'banana', 'birthday cake', 'bread', 'carrot', 'cookie',
        'donut', 'grapes', 'hamburger', 'hot dog', 'ice cream', 'broccoli',
        'mushroom', 'pear', 'pineapple', 'pizza', 'strawberry', 'watermelon',
    ],
    'Vehicles': [
        'airplane', 'bicycle', 'bus', 'car', 'firetruck', 'helicopter',
        'motorbike', 'cruise ship', 'sailboat', 'submarine', 'train', 'truck',
    ],
    'Objects': [
        'backpack', 'book', 'camera', 'chair', 'clock', 'computer', 'cup',
        'drums', 'fork', 'guitar', 'hammer', 'hat', 'key', 'knife', 'lantern',
        'microphone', 'pencil', 'piano', 'scissors', 'shoe', 'sword',
        'umbrella',
    ],
    'Nature': [
        'cloud', 'campfire', 'flower', 'leaf', 'lightning', 'moon',
        'mountain', 'rainbow', 'snowflake', 'star', 'sun', 'tree',
    ],
    'Buildings': [
        'bridge', 'castle', 'door', 'fence', 'house', 'lighthouse',
        'windmill',
    ],
    'Body': [
        'ear', 'eye', 'face', 'hand', 'nose', 'tooth',
    ],
    'Misc': [
        'circle', 'crown', 'diamond', 'bowtie', 'hot air balloon',
        'lollipop', 'skull', 'stop sign', 'tornado', 'cactus',
    ],
}

CLASSES = [cls for group in CATEGORIES.values() for cls in group]

BASE_URL = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'


def download_data(classes, data_dir=QUICKDRAW_DIR):
    """Download the QuickDraw numpy-bitmap file for each class into *data_dir*.

    Existing files are skipped. Each download is written to a temporary
    ``.part`` path and atomically renamed on success, so an interrupted
    write can never leave a truncated .npy that the existence check would
    mistake for a complete download on the next run. Failures are logged
    and skipped (best-effort); load_data() reports missing classes later.
    """
    os.makedirs(data_dir, exist_ok=True)
    for class_name in classes:
        file_name = f"{class_name}.npy"
        path = os.path.join(data_dir, file_name)
        if os.path.exists(path):
            print(f"Already exists: {file_name}")
            continue
        url = BASE_URL + class_name.replace(' ', '%20') + ".npy"
        tmp_path = path + '.part'
        try:
            r = requests.get(url, timeout=30)
            r.raise_for_status()
            with open(tmp_path, 'wb') as f:
                f.write(r.content)
            os.replace(tmp_path, path)  # atomic: file appears only when complete
            print(f"Downloaded: {file_name}")
        except Exception as e:  # pylint: disable=broad-exception-caught
            # Best-effort download loop: clean up any partial file and move on.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            print(f"Failed to download {class_name}: {e}")


def load_data(classes, max_samples_per_class=15000):
    """Load per-class .npy bitmaps, subsample, normalize, and stack them.

    Missing files are skipped; labels are assigned densely (0..k-1) over
    the classes actually found, so label indices match the returned
    ``available_classes`` list, not the requested ``classes`` list.

    Returns:
        (x, y, available_classes) where x is float32 of shape
        (N, 28, 28, 1) scaled to [0, 1] and y is an int label array.

    Raises:
        RuntimeError: if no class file could be loaded at all.
    """
    x_data, y_data, available_classes = [], [], []
    for class_name in classes:
        file_path = os.path.join(QUICKDRAW_DIR, f"{class_name}.npy")
        if not os.path.exists(file_path):
            print(f"Missing file: {class_name}")
            continue
        data = np.load(file_path)
        if data.shape[0] > max_samples_per_class:
            # Random subsample without replacement to cap memory/class imbalance.
            indices = np.random.choice(data.shape[0], max_samples_per_class, replace=False)
            data = data[indices]
        label_idx = len(available_classes)
        x_data.append(data)
        y_data.extend([label_idx] * data.shape[0])
        available_classes.append(class_name)
        print(f"Loaded {data.shape[0]} samples for '{class_name}'")
    if not x_data:
        raise RuntimeError("No data loaded. Check download step.")
    x_out = np.concatenate(x_data, axis=0).reshape(-1, 28, 28, 1).astype(np.float32) / 255.0
    y_out = np.array(y_data)
    return x_out, y_out, available_classes


def visualize_samples(x_data, y_data, classes, samples_per_class=5):
    """Save a grid of random sample drawings (one row per class) as a PNG.

    ``squeeze=False`` keeps the axes array 2-D even when there is a single
    class or a single column, so ``axes[row, col]`` indexing always works.
    """
    _, axes = plt.subplots(  # pylint: disable=too-many-function-args
        len(classes), samples_per_class,
        figsize=(samples_per_class * 2, len(classes) * 2),
        squeeze=False,
    )
    for class_idx, class_name in enumerate(classes):
        indices = np.where(y_data == class_idx)[0]
        # A class may have fewer samples than requested columns.
        n = min(samples_per_class, indices.size)
        samples = np.random.choice(indices, n, replace=False)
        for i, idx in enumerate(samples):
            ax = axes[class_idx, i]
            ax.imshow(x_data[idx].squeeze(), cmap='gray')
            ax.axis('off')
            if i == 0:
                ax.set_title(class_name, fontsize=10)
    plt.tight_layout()
    output_path = os.path.join(PROCESSED_DIR, "sample_drawings.png")
    plt.savefig(output_path, dpi=150)
    plt.close()
    print(f"Saved sample visualization to: {output_path}")


def split_and_save(x_data, y_data):
    """Stratified 70/10/20 train/val/test split, saved as .npy files.

    First split holds out 20% for test; the second takes 0.125 of the
    remaining 80% (= 10% overall) for validation. Fixed random_state
    makes the split reproducible.
    """
    x_temp, x_test, y_temp, y_test = train_test_split(
        x_data, y_data, test_size=0.2, stratify=y_data, random_state=42
    )
    x_train, x_val, y_train, y_val = train_test_split(
        x_temp, y_temp, test_size=0.125, stratify=y_temp, random_state=42
    )
    np.save(os.path.join(PROCESSED_DIR, 'X_train.npy'), x_train)
    np.save(os.path.join(PROCESSED_DIR, 'X_val.npy'), x_val)
    np.save(os.path.join(PROCESSED_DIR, 'X_test.npy'), x_test)
    np.save(os.path.join(PROCESSED_DIR, 'y_train.npy'), y_train)
    np.save(os.path.join(PROCESSED_DIR, 'y_val.npy'), y_val)
    np.save(os.path.join(PROCESSED_DIR, 'y_test.npy'), y_test)
    print(f"Saved datasets: {x_train.shape[0]} train, {x_val.shape[0]} val, {x_test.shape[0]} test")


def save_class_mappings(classes):
    """Write class<->index JSON mappings and copy the index map to MODEL_DIR.

    Note: JSON object keys are always strings, so the index->name file
    stores indices as string keys (consumers must int() them if needed).
    """
    os.makedirs(MODEL_DIR, exist_ok=True)
    class_to_idx = {cls: i for i, cls in enumerate(classes)}
    idx_to_class = dict(enumerate(classes))
    with open(os.path.join(PROCESSED_DIR, 'class_name_to_index.json'), 'w', encoding='utf-8') as f:
        json.dump(class_to_idx, f, indent=2)
    with open(os.path.join(PROCESSED_DIR, 'index_to_class_name.json'), 'w', encoding='utf-8') as f:
        json.dump(idx_to_class, f, indent=2)
    shutil.copyfile(
        os.path.join(PROCESSED_DIR, 'index_to_class_name.json'),
        os.path.join(MODEL_DIR, 'classes.json')
    )
    print("Saved class mappings")


def update_readme_classes(available_classes):
    """Rewrite the '## Supported Categories' section of README.md.

    Only classes that were actually loaded are listed, grouped by the
    CATEGORIES layout. If the section markers are not found the README
    is left untouched and a warning is printed (instead of silently
    writing back unchanged content).
    """
    readme_path = os.path.join(ROOT_DIR, 'README.md')
    available_set = set(available_classes)
    lines = []
    total = len(available_classes)
    lines.append(f"{total} categories across {len(CATEGORIES)} groups:")
    for group, members in CATEGORIES.items():
        present = [m for m in members if m in available_set]
        if present:
            lines.append(f"**{group}**: {', '.join(present)}")
    new_section = '\n'.join(lines)
    with open(readme_path, 'r', encoding='utf-8') as f:
        content = f.read()
    pattern = r'(## Supported Categories\n\n).*?(\n## )'
    # Use a replacement callable so backslashes in the generated text can
    # never be misread as regex group references or escapes by re.sub.
    new_content, n_subs = re.subn(
        pattern,
        lambda m: m.group(1) + new_section + '\n' + m.group(2),
        content,
        flags=re.DOTALL,
    )
    if n_subs == 0:
        print("Warning: '## Supported Categories' section not found; README.md unchanged")
        return
    with open(readme_path, 'w', encoding='utf-8') as f:
        f.write(new_content)
    print(f"Updated README.md with {total} classes")


def main():
    """Run the full pipeline: download, load, visualize, split, save mappings."""
    print("Preparing QuickDraw dataset...")
    download_data(CLASSES)
    x_data, y_data, available_classes = load_data(CLASSES)
    visualize_samples(x_data, y_data, available_classes)
    split_and_save(x_data, y_data)
    save_class_mappings(available_classes)
    update_readme_classes(available_classes)
    print("Done. Run scripts/train_model.py to train the model.")


if __name__ == "__main__":
    main()