import os
import json
import asyncio
import logging
import requests
from datetime import datetime
from typing import Optional
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from huggingface_hub import HfApi, hf_hub_url
from tqdm import tqdm

# --- Configuration ---
SOURCE_REPO_ID = "Fred808/TGFiles"
TARGET_REPO_ID = "samfred2/TGFiles"
REPO_TYPE = "dataset"
REVISION = "main"
STATE_FILE = "sync_state.json"
TOKEN = os.environ.get("HF_TOKEN", "")

# Whether the service should automatically start syncing on FastAPI startup.
# Set the AUTO_START env var to "0", "false" or "no" to disable.
AUTO_START = os.environ.get("AUTO_START", "true").lower() in ("1", "true", "yes")

# How long (seconds) before a remote "running" state is considered stale and ignored.
# Default: 1 hour. Set the STATE_STALE_SECONDS env var to override.
try:
    STATE_STALE_SECONDS = int(os.environ.get("STATE_STALE_SECONDS", "3600"))
except Exception:
    STATE_STALE_SECONDS = 3600

# --- FastAPI App ---
app = FastAPI(title="HF Dataset Sync Service")

# Configure basic logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# --- Auto-start handler ---
@app.on_event("startup")
async def _maybe_start_sync_on_startup():
    """If AUTO_START is enabled, start synchronization in the background on app startup.

    A saved "running" state whose last_updated timestamp is older than
    STATE_STALE_SECONDS is treated as stale and cleared; otherwise the sync
    simply resumes from the saved state.
    """
    try:
        if not AUTO_START:
            logger.info("AUTO_START disabled; not starting sync on startup")
            return

        # Attempt to download remote state first so we resume from the last-known state
        try:
            download_remote_state(TOKEN)
        except Exception:
            logger.exception("Failed to download remote state on startup; continuing with local state if present")

        # If the downloaded/loaded state reports "running", it may be stale
        # (service crashed or was stopped). In that case we consider it stale
        # if last_updated is older than STATE_STALE_SECONDS and clear it so the
        # sync can resume.
        state = load_state()
        if state.get("status") == "running":
            last = state.get("last_updated")
            stale = False
            if last:
                try:
                    last_dt = datetime.fromisoformat(last)
                    age = (datetime.utcnow() - last_dt).total_seconds()
                    if age > STATE_STALE_SECONDS:
                        stale = True
                        logger.info(
                            "Remote state marked running but last_updated is %.0f seconds old (>%s); treating as stale",
                            age,
                            STATE_STALE_SECONDS,
                        )
                except Exception:
                    # If we can't parse last_updated, treat as stale to be safe
                    stale = True
                    logger.exception("Failed to parse last_updated from remote state; treating running state as stale")

            if stale:
                state["status"] = "idle"
                state["current_file"] = None
                # Upload the corrected state back to the remote repo so other instances see it
                try:
                    save_state(state, token=TOKEN)
                except Exception:
                    logger.exception("Failed to upload corrected stale state; continuing with local corrected state")
            else:
                # Remote state indicates a sync was running recently and is not stale.
                # We assume there's no other live process and resume that sync here.
logger.info("Remote state indicates an in-progress sync; resuming from that state") # fall through to starting the sync below # Launch the async synchronization as a background task without blocking startup logger.info("AUTO_START enabled; launching dataset synchronization in background") asyncio.create_task(synchronize_datasets(TOKEN)) except Exception as e: logger.exception("Failed to auto-start synchronization: %s", e) def download_remote_state(token: Optional[str] = None): """Attempt to download `STATE_FILE` from the target repo and overwrite local state. If no remote state exists (404) this function returns quietly. Any other exceptions are raised for the caller to handle/log. """ try: download_url = hf_hub_url( repo_id=TARGET_REPO_ID, filename=STATE_FILE, repo_type=REPO_TYPE, revision=REVISION, ) resp = requests.get(download_url) if resp.status_code == 200: with open(STATE_FILE, "wb") as f: f.write(resp.content) logger.info("Remote state downloaded from %s/%s", TARGET_REPO_ID, STATE_FILE) else: logger.info("Remote state not found (status %s) — continuing with local state if present", resp.status_code) except Exception: logger.exception("Failed to download remote state file") raise # --- Models --- class SyncStatus(BaseModel): status: str total_files: int synced_files: int failed_files: int current_file: Optional[str] = None progress_percent: float last_updated: str class SyncRequest(BaseModel): token: Optional[str] = None # --- State Management --- def load_state(): """Loads the synchronization state from a JSON file.""" if os.path.exists(STATE_FILE): with open(STATE_FILE, "r") as f: return json.load(f) return { "synced_files": [], "failed_files": [], "total_files": 0, "status": "idle", "last_updated": datetime.utcnow().isoformat(), "current_file": None } def save_state(state, api: Optional[HfApi] = None, token: Optional[str] = None): """Save the state locally and optionally upload it to the target repo. If `api` is provided it will be used to upload the `STATE_FILE` to the configured `TARGET_REPO_ID`. If `api` is None but `token` is provided, a temporary HfApi will be created for upload. """ state["last_updated"] = datetime.utcnow().isoformat() with open(STATE_FILE, "w") as f: json.dump(state, f, indent=4) # Try to upload the state file to the target repo so restarts can resume try: upload_api = api if upload_api is None and token: upload_api = HfApi(token=token) if upload_api is not None: # upload the local state file to the target dataset try: upload_api.upload_file( path_or_fileobj=STATE_FILE, path_in_repo=STATE_FILE, repo_id=TARGET_REPO_ID, repo_type=REPO_TYPE, commit_message=f"Sync: Update {STATE_FILE}", ) logger.info("Uploaded state file to %s/%s", TARGET_REPO_ID, STATE_FILE) except Exception: # Log and continue — failing to upload state should not crash the service logger.exception("Failed to upload state file to remote repo") except Exception: # Catch-all for unexpected errors constructing HfApi logger.exception("Unexpected error while attempting to save/upload state") # Helper to perform blocking download/upload in a thread def _download_and_upload_file(api: HfApi, file_path: str): """Download a file from the source repo and upload it to the target repo. This function is synchronous and intended to be run with asyncio.to_thread so that blocking I/O doesn't block the event loop. 
""" # Build download URL and local filename download_url = hf_hub_url( repo_id=SOURCE_REPO_ID, filename=file_path, repo_type=REPO_TYPE, revision=REVISION ) local_path = os.path.basename(file_path) # Download using requests (blocking) response = requests.get(download_url, stream=True) response.raise_for_status() with open(local_path, "wb") as f: for chunk in response.iter_content(chunk_size=8192): if chunk: f.write(chunk) # Upload file to target repo api.upload_file( path_or_fileobj=local_path, path_in_repo=file_path, repo_id=TARGET_REPO_ID, repo_type=REPO_TYPE, commit_message=f"Sync: Add {file_path} from {SOURCE_REPO_ID}", ) # Clean up local file try: os.remove(local_path) except Exception: # If cleanup fails, just continue; file may be left behind for inspection logger.exception("Failed to remove local file %s", local_path) # --- Synchronization Logic --- async def synchronize_datasets(token: str): """ Fetches file list from source, downloads files, and uploads them to target, persisting state to resume progress. """ state = load_state() # Initialize HF API first so we can persist state remotely as we progress. try: api = HfApi(token=token) except Exception as e: state["status"] = "error" state["error"] = f"Error initializing HfApi: {str(e)}" save_state(state) return state["status"] = "running" save_state(state, api=api) synced_files = set(state["synced_files"]) failed_files = set(state.get("failed_files", [])) try: repo_files = api.list_repo_files( repo_id=SOURCE_REPO_ID, repo_type=REPO_TYPE, revision=REVISION ) except Exception as e: state["status"] = "error" state["error"] = f"Error fetching file list: {str(e)}" save_state(state) return state["total_files"] = len(repo_files) files_to_sync = [f for f in repo_files if f not in synced_files and f not in failed_files] for idx, file_path in enumerate(files_to_sync): if file_path in synced_files: continue state["current_file"] = file_path state["synced_files_count"] = len(synced_files) state["progress_percent"] = (len(synced_files) / state["total_files"]) * 100 if state["total_files"] > 0 else 0 save_state(state, api=api) # Skip .gitattributes if file_path == ".gitattributes": synced_files.add(file_path) continue try: # Perform blocking download+upload in a thread to avoid blocking the event loop await asyncio.to_thread(_download_and_upload_file, api, file_path) # Mark as synced and persist (locally + remote) synced_files.add(file_path) state["synced_files"] = list(synced_files) save_state(state, api=api) except Exception as e: logger.exception("Error syncing file %s: %s", file_path, e) failed_files.add(file_path) state["failed_files"] = list(failed_files) save_state(state, api=api) # Wait 2 minutes between processing files to throttle downloads/uploads # Skip waiting after the last file if idx != len(files_to_sync) - 1: state["status"] = "running" save_state(state, api=api) logger.info("Waiting 120 seconds before processing next file") await asyncio.sleep(140) state["status"] = "completed" state["current_file"] = None state["synced_files_count"] = len(synced_files) state["progress_percent"] = 100.0 save_state(state, api=api) # --- Endpoints --- @app.get("/") async def root(): """Health check endpoint.""" return {"status": "ok", "service": "HF Dataset Sync Service"} @app.get("/status", response_model=SyncStatus) async def get_status(): """Get current synchronization status.""" state = load_state() return SyncStatus( status=state.get("status", "idle"), total_files=state.get("total_files", 0), synced_files=len(state.get("synced_files", [])), 
        failed_files=len(state.get("failed_files", [])),
        current_file=state.get("current_file"),
        progress_percent=state.get("progress_percent", 0.0),
        last_updated=state.get("last_updated", datetime.utcnow().isoformat()),
    )


@app.post("/sync")
async def start_sync(request: SyncRequest, background_tasks: BackgroundTasks):
    """Start the synchronization process."""
    state = load_state()
    if state.get("status") == "running":
        raise HTTPException(status_code=409, detail="Sync is already running")

    token = request.token or TOKEN
    background_tasks.add_task(synchronize_datasets, token)
    return {"message": "Sync started", "status": "running"}


@app.post("/reset")
async def reset_state():
    """Reset the synchronization state."""
    state = {
        "synced_files": [],
        "failed_files": [],
        "total_files": 0,
        "status": "idle",
        "last_updated": datetime.utcnow().isoformat(),
        "current_file": None,
    }
    save_state(state)
    return {"message": "State reset", "status": "idle"}


@app.get("/state")
async def get_full_state():
    """Get the full synchronization state."""
    return load_state()


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
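
# --- Example usage (illustrative sketch, not exercised by the code above) ---
# A minimal way to run and drive the service from a shell, assuming this module
# is saved as app.py and HF_TOKEN has write access to the target repo; the file
# name and port are assumptions, not fixed by the code above:
#
#   export HF_TOKEN=hf_xxx
#   uvicorn app:app --host 0.0.0.0 --port 8000   # or: python app.py
#
#   curl -X POST http://localhost:8000/sync      # start (or resume) a sync
#   curl http://localhost:8000/status            # summarized progress
#   curl http://localhost:8000/state             # full persisted state (sync_state.json)
#   curl -X POST http://localhost:8000/reset     # clear saved progress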