# NOTE: "Spaces: Sleeping" — Hugging Face Spaces status banner accidentally
# captured with this file; it is not part of the module.
"""
Hugging Face uploader class with rate limiting
"""
import os
import asyncio
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple

import aiohttp
import logging

from sqlalchemy.orm import Session
from models import (
    RateLimitLog,
    UploadQueue,
    UploadStatusEnum,
    UploadErrorLog,
    HFConfig,
)

# Module-level logger, keyed to this module's import path.
logger = logging.getLogger(__name__)
class RateLimiter:
    """Tracks and enforces an hourly cap on Hugging Face uploads.

    State is persisted in ``RateLimitLog`` rows keyed by the start of the
    wall-clock hour, so the limit survives process restarts.
    """

    def __init__(self, max_uploads_per_hour: int = 128):
        """
        Initialize the rate limiter.

        Args:
            max_uploads_per_hour: Maximum uploads allowed per hour
        """
        self.max_uploads_per_hour = max_uploads_per_hour

    def get_current_hour_window(self) -> Tuple[datetime, datetime]:
        """Return (start, end) of the wall-clock hour containing now (UTC, naive)."""
        window_start = datetime.utcnow().replace(minute=0, second=0, microsecond=0)
        return window_start, window_start + timedelta(hours=1)

    def _log_for_current_hour(self, db: Session):
        """Fetch the RateLimitLog row for the current hour window, or None."""
        window_start, _ = self.get_current_hour_window()
        return (
            db.query(RateLimitLog)
            .filter(RateLimitLog.hour_start == window_start)
            .first()
        )

    async def check_rate_limit(self, db: Session) -> Dict:
        """
        Check whether an upload is allowed in the current hour.

        Args:
            db: Database session

        Returns:
            Dict with keys ``can_upload``, ``remaining_uploads`` and
            ``resume_time`` (``resume_time`` is only set once the limit
            has been explicitly marked as hit).
        """
        entry = self._log_for_current_hour(db)

        # No row yet for this hour: the full quota is available.
        if entry is None:
            return {
                "can_upload": True,
                "remaining_uploads": self.max_uploads_per_hour,
                "resume_time": None,
            }

        # A 429 was observed this hour: hard stop until resume_time.
        if entry.limit_hit:
            return {
                "can_upload": False,
                "remaining_uploads": 0,
                "resume_time": entry.resume_time,
            }

        left = self.max_uploads_per_hour - entry.upload_count
        return {
            "can_upload": left > 0,
            "remaining_uploads": max(0, left),
            "resume_time": None,
        }

    async def increment_counter(self, db: Session) -> None:
        """
        Record one successful upload against the current hour's quota.

        Args:
            db: Database session
        """
        entry = self._log_for_current_hour(db)
        if entry is not None:
            entry.upload_count += 1
        else:
            window_start, window_end = self.get_current_hour_window()
            db.add(
                RateLimitLog(
                    upload_count=1,
                    hour_start=window_start,
                    hour_end=window_end,
                    limit_hit=False,
                )
            )
        db.commit()

    async def mark_limit_hit(self, db: Session) -> None:
        """
        Flag the current hour as rate-limited (e.g. after an HTTP 429).

        Args:
            db: Database session
        """
        entry = self._log_for_current_hour(db)
        if entry is not None:
            _, window_end = self.get_current_hour_window()
            entry.limit_hit = True
            # Resume one second into the next hour window.
            entry.resume_time = window_end + timedelta(seconds=1)
            db.commit()
class HFUploader:
    """Object-oriented Hugging Face uploader with hourly rate limiting."""

    def __init__(self, hf_token: str, target_repo: str):
        """
        Initialize uploader.

        Args:
            hf_token: Hugging Face API token; when empty/None, falls back
                to the ``HF_TOKEN`` environment variable.
            target_repo: Target repository ID (e.g., "samfred2/ALL2")
        """
        # BUG FIX: the original ignored the hf_token argument and always
        # read HF_TOKEN from the environment. Prefer the explicit argument.
        self.hf_token = hf_token or os.getenv("HF_TOKEN")
        self.target_repo = target_repo
        self.rate_limiter = RateLimiter(max_uploads_per_hour=128)

    async def upload_file(
        self, file_path: str, file_name: str, db: Session
    ) -> Dict:
        """
        Upload a single file to Hugging Face.

        Args:
            file_path: Path to file to upload
            file_name: Name of file in repository
            db: Database session

        Returns:
            Upload result dictionary with ``success``, ``file_name``,
            ``message`` and, on failure, ``retryable`` / ``status_code``.
        """
        try:
            # Missing files are a permanent failure — do not retry.
            if not os.path.exists(file_path):
                return {
                    "success": False,
                    "file_name": file_name,
                    "message": "File not found",
                    "retryable": False,
                }

            # Read file fully into memory before the async request.
            with open(file_path, "rb") as f:
                file_content = f.read()

            # NOTE(review): endpoint/form layout assumed correct per the
            # original implementation — verify against the HF Hub API docs.
            url = f"https://huggingface.co/api/datasets/{self.target_repo}/upload"

            async with aiohttp.ClientSession() as session:
                data = aiohttp.FormData()
                data.add_field(
                    "files",
                    file_content,
                    filename=file_name,
                    content_type="application/json",
                )
                headers = {"Authorization": f"Bearer {self.hf_token}"}

                async with session.post(url, data=data, headers=headers) as response:
                    # 429: record the limit so the batch loop pauses for an hour.
                    if response.status == 429:
                        logger.warning(f"Rate limit hit (429) for {file_name}")
                        await self.rate_limiter.mark_limit_hit(db)
                        return {
                            "success": False,
                            "file_name": file_name,
                            "message": "Rate limit hit (429). Will retry after 1 hour.",
                            "status_code": 429,
                            "retryable": True,
                        }

                    # Any other non-200: classify as retryable only for
                    # server errors (5xx), timeouts (408) and 429.
                    if response.status != 200:
                        error_text = await response.text()
                        retryable = (
                            response.status >= 500
                            or response.status == 408
                            or response.status == 429
                        )
                        logger.error(
                            f"Upload failed for {file_name}: HTTP {response.status}"
                        )
                        return {
                            "success": False,
                            "file_name": file_name,
                            "message": f"Upload failed: {response.reason}. {error_text}",
                            "status_code": response.status,
                            "retryable": retryable,
                        }

                    # Success: count it against the hourly quota.
                    await self.rate_limiter.increment_counter(db)
                    logger.info(f"Successfully uploaded {file_name}")
                    return {
                        "success": True,
                        "file_name": file_name,
                        "message": "File uploaded successfully",
                    }
        except Exception as e:
            # Network/IO errors are treated as transient and retryable.
            logger.error(f"Upload error for {file_name}: {e}")
            return {
                "success": False,
                "file_name": file_name,
                "message": f"Upload error: {str(e)}",
                "retryable": True,
            }

    async def upload_files_batch(
        self,
        files: List[Dict],
        db: Session,
        batch_size: int = 10,
    ) -> Dict:
        """
        Upload multiple files with rate limiting.

        Args:
            files: List of dicts with 'id', 'file_name', 'file_path'
            db: Database session
            batch_size: Number of files to process before checking rate limit
                (currently unused — the limit is checked before every upload)

        Returns:
            Dict with ``successful``, ``failed``, ``paused`` counts/flag
            and the per-file ``results`` list.
        """
        results = []
        successful = 0
        failed = 0
        paused = False

        for file_info in files:
            # Check rate limit before each upload; stop the batch entirely
            # once the quota is exhausted.
            rate_check = await self.rate_limiter.check_rate_limit(db)
            if not rate_check["can_upload"]:
                logger.info(
                    f"Rate limit reached. Pausing uploads. Resume at: {rate_check['resume_time']}"
                )
                paused = True
                break

            result = await self.upload_file(
                file_info["file_path"], file_info["file_name"], db
            )
            results.append(result)

            # Single queue lookup shared by both outcome branches
            # (the original duplicated this query).
            queue_item = (
                db.query(UploadQueue)
                .filter(UploadQueue.id == file_info["id"])
                .first()
            )

            if result["success"]:
                successful += 1
                if queue_item:
                    queue_item.status = UploadStatusEnum.COMPLETED
                    queue_item.uploaded_at = datetime.utcnow()
                    db.commit()
            else:
                failed += 1
                if queue_item:
                    queue_item.status = UploadStatusEnum.FAILED
                    queue_item.failure_reason = result["message"]
                    queue_item.retry_count += 1
                    db.commit()

                # BUG FIX: the original stored the string "None" as
                # error_code when no HTTP status was available.
                status_code = result.get("status_code")
                error_log = UploadErrorLog(
                    file_name=file_info["file_name"],
                    error_code=str(status_code) if status_code is not None else None,
                    error_message=result["message"],
                    status_code=status_code,
                    retryable=result.get("retryable", True),
                )
                db.add(error_log)
                db.commit()

            # Small delay between uploads to smooth out request bursts.
            await asyncio.sleep(0.5)

        return {
            "successful": successful,
            "failed": failed,
            "paused": paused,
            "results": results,
        }

    async def get_upload_status(self, db: Session) -> Dict:
        """
        Get current upload status.

        Args:
            db: Database session

        Returns:
            Dict with the current ``rate_limit`` check result and the
            stored ``config`` (or None when no HFConfig row exists).
        """
        rate_check = await self.rate_limiter.check_rate_limit(db)
        config = db.query(HFConfig).first()

        config_payload = None
        if config:
            config_payload = {
                "max_uploads_per_hour": config.max_uploads_per_hour,
                "upload_batch_size": config.upload_batch_size,
                "target_repo": config.target_repo,
            }

        return {
            "rate_limit": rate_check,
            "config": config_payload,
        }