import asyncio
import io
import json
import os
import time
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor
from dataclasses import asdict, dataclass
from pathlib import Path

import aiofiles
import aiohttp
from PIL import Image

@dataclass
class ProcessState:
    urls_processed: int = 0
    images_downloaded: int = 0
    images_saved: int = 0
    images_resized: int = 0

class ImageProcessor(ABC):
    @abstractmethod
    async def process(self, image: bytes, filename: str) -> None:
        pass

class ImageSaver(ImageProcessor):
    async def process(self, image: bytes, filename: str) -> None:
        async with aiofiles.open(filename, "wb") as f:
            await f.write(image)
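
# Why ImageProcessor is an ABC: other processors can be slotted into the
# pipeline without changing it. GrayscaleSaver below is a hypothetical
# example of such a processor, not part of the original pipeline.
class GrayscaleSaver(ImageProcessor):
    async def process(self, image: bytes, filename: str) -> None:
        # Pillow is synchronous; acceptable for a sketch, though real code
        # would offload this to an executor as the pipeline does for resizing.
        with Image.open(io.BytesIO(image)) as img:
            img.convert("L").save(filename)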

def resize_image(image: bytes, filename: str, max_size: int = 300) -> None:
    with Image.open(io.BytesIO(image)) as img:
        img.thumbnail((max_size, max_size))
        img.save(filename, optimize=True, quality=85)
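
# resize_image is deliberately a module-level function: ProcessPoolExecutor
# sends the callable to worker processes by pickling it, and only importable
# top-level functions pickle reliably. A lambda, or a method on ImagePipeline
# (whose self holds an event loop and an executor), would fail to pickle.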

class RateLimiter:
    """
    High-Level Concept: The Token Bucket Algorithm
    ==============================================
    The RateLimiter class implements what's known as the "Token Bucket"
    algorithm. Imagine you have a bucket that can hold a certain number of
    tokens. Here's how it works:

    - The bucket is filled with tokens at a constant rate.
    - When you want to perform an action (in our case, make an API request),
      you need to take a token from the bucket.
    - If a token is available, you can perform the action immediately.
    - If there are no tokens, you have to wait until a new token is added.
    - The bucket has a maximum capacity, so tokens don't accumulate
      indefinitely when not used.

    This mechanism allows for both steady-state rate limiting and handling
    short bursts of activity.

    In the constructor:
    ===================
    - rate: how many tokens we add per time period (e.g., 10 tokens per second)
    - per: the time period (usually 1 second)
    - burst: the bucket size (maximum number of tokens)

    We start with a full bucket (self.tokens = burst) and note the current
    time (self.updated_at).

    Logic:
    ======
    1. Calculate how much time has passed since we last updated the token count.
    2. Add tokens based on the time passed and our rate:
       self.tokens += time_passed * (self.rate / self.per)
    3. If we've added too many tokens, cap the count at our maximum (burst size).
    4. Update our "last updated" time.
    5. If we have at least one token:
       - Remove a token (self.tokens -= 1)
       - Return immediately, allowing the API call to proceed
    6. If we don't have a token:
       - Calculate how long we need to wait for the next token
       - Sleep for that duration

    Let's walk through an example:
    ==============================
    Suppose we set up our RateLimiter like this:

        limiter = RateLimiter(rate=10, per=1, burst=10)

    This means:
    - We allow 10 requests per second on average
    - We can burst up to 10 requests at once
    - After the burst, we'll be limited to 1 request every 0.1 seconds

    Now, imagine a sequence of API calls:
    1. The first 10 calls happen immediately (burst capacity)
    2. The 11th call waits 0.1 seconds (the time to generate 1 token)
    3. Subsequent calls each wait about 0.1 seconds

    If there's a pause in API calls, tokens accumulate (up to the burst
    limit), allowing for another burst of activity.

    This mechanism ensures that:
    1. We respect the average rate limit (10 per second in this example)
    2. We can handle short bursts of activity (up to 10 at once)
    3. We smoothly regulate requests when operating at capacity
    """

    def __init__(self, rate: float, per: float = 1.0, burst: int = 1):
        self.rate = rate
        self.per = per
        self.burst = burst
        self.tokens = burst
        self.updated_at = time.monotonic()

    async def wait(self):
        while True:
            now = time.monotonic()
            time_passed = now - self.updated_at
            # Refill: add tokens for the elapsed time, capped at the burst size.
            self.tokens += time_passed * (self.rate / self.per)
            if self.tokens > self.burst:
                self.tokens = self.burst
            self.updated_at = now
            if self.tokens >= 1:
                # Spend one token and let the caller proceed.
                self.tokens -= 1
                return
            # Not enough tokens: sleep just long enough for the deficit to
            # refill, then re-check (another waiter may win the race).
            await asyncio.sleep((1 - self.tokens) / (self.rate / self.per))
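
# A minimal, hypothetical sketch (not part of the pipeline) showing the
# burst-then-throttle behavior: with rate=2, per=1, burst=2, the first two
# wait() calls return immediately and each subsequent call sleeps ~0.5 s.
# Try it with: asyncio.run(_demo_rate_limiter())
async def _demo_rate_limiter():
    limiter = RateLimiter(rate=2, per=1, burst=2)
    start = time.monotonic()
    for i in range(4):
        await limiter.wait()
        print(f"call {i} released at t={time.monotonic() - start:.2f}s")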

class ImagePipeline:
    def __init__(
        self,
        txt_file: str,
        loop: asyncio.AbstractEventLoop,
        max_concurrent_downloads: int = 10,
        max_workers: int = max((os.cpu_count() or 4) - 4, 4),
        rate_limit: float = 10,
        rate_limit_period: float = 1,
        downloaded_images_dir: str = "",
    ):
        self.txt_file = txt_file
        self.loop = loop
        self.url_queue = asyncio.Queue(maxsize=1000)
        self.image_queue = asyncio.Queue(maxsize=100)
        self.semaphore = asyncio.Semaphore(max_concurrent_downloads)
        self.state = ProcessState()
        self.state_file = "pipeline_state.json"
        self.saver = ImageSaver()
        self.process_pool = ProcessPoolExecutor(max_workers=max_workers)
        self.rate_limiter = RateLimiter(
            rate=rate_limit, per=rate_limit_period, burst=max_concurrent_downloads
        )
        self.downloaded_images_dir = Path(downloaded_images_dir)
        # Create the output directory up front so saves don't fail later.
        self.downloaded_images_dir.mkdir(parents=True, exist_ok=True)
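
    # Pipeline shape: bounded queues connect the stages and provide
    # backpressure when a downstream stage falls behind:
    #
    #   url_feeder       -> url_queue (1000)  -> image_downloader
    #   image_downloader -> image_queue (100) -> image_processor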

    async def url_feeder(self):
        try:
            print(f"Starting to read URLs from {self.txt_file}")
            async with aiofiles.open(self.txt_file, mode="r") as f:
                line_number = 0
                async for line in f:
                    line_number += 1
                    # On resume, skip lines that were already handed off.
                    if line_number <= self.state.urls_processed:
                        continue
                    url = line.strip()
                    if url:  # Skip empty lines
                        # put() blocks when the queue is full, so the bounded
                        # queue itself applies backpressure.
                        await self.url_queue.put(url)
                        self.state.urls_processed += 1
        except Exception as e:
            print(f"Error in url_feeder: {e}")
        finally:
            await self.url_queue.put(None)
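
    # End-of-stream protocol: a single None sentinel flows through url_queue,
    # and the downloader forwards a matching None into image_queue, so each
    # stage shuts down in order once the input file is exhausted.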

    async def image_downloader(self):
        print("Starting image downloader")
        async with aiohttp.ClientSession() as session:
            while True:
                url = await self.url_queue.get()
                if url is None:
                    print("Finished downloading images")
                    await self.image_queue.put(None)
                    break
                try:
                    await self.rate_limiter.wait()  # Respect the rate limit
                    async with self.semaphore:
                        async with session.get(url) as response:
                            if response.status == 200:
                                image = await response.read()
                                await self.image_queue.put((image, url))
                                self.state.images_downloaded += 1
                                if self.state.images_downloaded % 100 == 0:
                                    print(
                                        f"Downloaded {self.state.images_downloaded} images"
                                    )
                            else:
                                print(f"Skipping {url}: HTTP {response.status}")
                except Exception as e:
                    print(f"Error downloading {url}: {e}")
                finally:
                    self.url_queue.task_done()
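
    # The semaphore caps how many downloads may be in flight at once, while
    # the rate limiter caps how fast new requests may start. With a single
    # downloader task the semaphore never contends; it becomes meaningful if
    # several image_downloader tasks are launched against the shared queues.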

    async def image_processor(self):
        print("Starting image processor")
        while True:
            item = await self.image_queue.get()
            if item is None:
                print("Finished processing images")
                break
            image, url = item
            filename = os.path.basename(url)
            if not filename.lower().endswith((".png", ".jpg", ".jpeg")):
                filename += ".png"
            try:
                # Save the original image
                await self.saver.process(
                    image, str(self.downloaded_images_dir / f"original_{filename}")
                )
                self.state.images_saved += 1
                if self.state.images_saved % 100 == 0:
                    print(f"Saved {self.state.images_saved} images")
                # Resize in the process pool so the CPU-bound Pillow work
                # doesn't block the event loop.
                await self.loop.run_in_executor(
                    self.process_pool,
                    resize_image,
                    image,
                    str(self.downloaded_images_dir / f"resized_{filename}"),
                )
                self.state.images_resized += 1
            except Exception as e:
                print(f"Error processing {url}: {e}")
            finally:
                self.image_queue.task_done()
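
    # Note: because image_processor awaits each resize before pulling the next
    # item, the process pool performs at most one resize at a time. To keep
    # all workers busy, one could run several image_processor tasks or gather
    # the executor futures in batches.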

    def save_state(self):
        with open(self.state_file, "w") as f:
            json.dump(asdict(self.state), f)

    def load_state(self):
        if os.path.exists(self.state_file):
            with open(self.state_file, "r") as f:
                self.state = ProcessState(**json.load(f))
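
    # Resume caveat: url_feeder skips lines based on the persisted
    # urls_processed, which counts URLs *queued*, not *completed*. URLs that
    # were still in flight during a crash are therefore skipped on restart
    # rather than retried.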

    async def run(self):
        print("Starting pipeline")
        self.load_state()
        print(f"Loaded state: {self.state}")
        tasks = [
            asyncio.create_task(self.url_feeder()),
            asyncio.create_task(self.image_downloader()),
            asyncio.create_task(self.image_processor()),
        ]
        try:
            await asyncio.gather(*tasks)
        except Exception as e:
            print(f"Pipeline error: {e}")
        finally:
            self.save_state()
            print(f"Final state: {self.state}")
            self.process_pool.shutdown()
            print("Pipeline finished")

if __name__ == "__main__":
    PROJECT_ROOT = Path(__file__).resolve().parent

    text_file = PROJECT_ROOT / "data/image_urls.txt"
    if not text_file.exists():
        import pandas as pd

        dataframe = pd.read_csv(PROJECT_ROOT / "data/photos.tsv000", sep="\t")
        num_image_urls = len(dataframe)
        print(f"Number of image urls: {num_image_urls}")
        with open(text_file, "w") as f:
            for url in dataframe["photo_image_url"]:
                f.write(url + "\n")

    print("Started downloading images")
    # asyncio.get_event_loop() is deprecated outside a running loop;
    # create the loop explicitly instead.
    loop = asyncio.new_event_loop()
    pipeline = ImagePipeline(
        txt_file=str(text_file),
        loop=loop,
        rate_limit=100,
        rate_limit_period=1,
        downloaded_images_dir=str(PROJECT_ROOT / "data/data/images"),
    )
    # asyncio.run(pipeline.run())
    loop.run_until_complete(pipeline.run())
    print("Finished downloading images")