|
|
import os |
|
|
import json |
|
|
import asyncio |
|
|
import logging |
|
|
import requests |
|
|
from datetime import datetime |
|
|
from typing import Optional |
|
|
from fastapi import FastAPI, HTTPException, BackgroundTasks |
|
|
from fastapi.responses import JSONResponse |
|
|
from pydantic import BaseModel |
|
|
from huggingface_hub import HfApi, hf_hub_url |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
# Source and destination Hugging Face dataset repositories for the sync.
SOURCE_REPO_ID = "Fred808/TGFiles"
TARGET_REPO_ID = "samfred2/TGFiles"
REPO_TYPE = "dataset"
REVISION = "main"
# Local JSON file that persists sync progress between restarts; also mirrored
# to the target repo so a fresh container can resume a previous run.
STATE_FILE = "sync_state.json"
# Hugging Face access token; empty string means unauthenticated requests.
TOKEN = os.environ.get("HF_TOKEN", "")

# Whether to launch a sync automatically when the app starts (default: on).
AUTO_START = os.environ.get("AUTO_START", "true").lower() in ("1", "true", "yes")

# Age (seconds) after which a persisted "running" status is considered stale
# (the previous process presumably died) and is reset to idle on startup.
try:
    STATE_STALE_SECONDS = int(os.environ.get("STATE_STALE_SECONDS", "3600"))
except Exception:
    # Fall back to one hour if the env var is not a valid integer.
    STATE_STALE_SECONDS = 3600

app = FastAPI(title="HF Dataset Sync Service")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
@app.on_event("startup")
async def _maybe_start_sync_on_startup():
    """If AUTO_START is enabled, start synchronization in the background on app startup.

    Pulls the remote copy of the state file first so a restarted instance
    resumes where the previous one stopped. A persisted "running" status whose
    last_updated is older than STATE_STALE_SECONDS (or unparseable) is treated
    as a dead run and reset to idle before the sync is launched.
    """
    try:
        if not AUTO_START:
            logger.info("AUTO_START disabled; not starting sync on startup")
            return

        # Best-effort refresh of local state from the target repo.
        try:
            download_remote_state(TOKEN)
        except Exception:
            logger.exception("Failed to download remote state on startup; continuing with local state if present")

        state = load_state()
        if state.get("status") == "running":
            last = state.get("last_updated")
            stale = False
            if last:
                try:
                    last_dt = datetime.fromisoformat(last)
                    age = (datetime.utcnow() - last_dt).total_seconds()
                    if age > STATE_STALE_SECONDS:
                        stale = True
                        logger.info("Remote state marked running but last_updated is %.0f seconds old (>%s); treating as stale", age, STATE_STALE_SECONDS)
                except Exception:
                    # An unparseable timestamp gives no evidence of a live run,
                    # so assume the old process died.
                    stale = True
                    logger.exception("Failed to parse last_updated from remote state; treating running state as stale")

            if stale:
                state["status"] = "idle"
                state["current_file"] = None
                try:
                    save_state(state, token=TOKEN)
                except Exception:
                    logger.exception("Failed to upload corrected stale state; continuing with local corrected state")
            else:
                logger.info("Remote state indicates an in-progress sync; resuming from that state")

        logger.info("AUTO_START enabled; launching dataset synchronization in background")
        # Bug fix: keep a strong reference to the task. asyncio.create_task
        # only stores a weak reference internally, so an unreferenced task can
        # be garbage-collected before it finishes.
        app.state.sync_task = asyncio.create_task(synchronize_datasets(TOKEN))
    except Exception as e:
        logger.exception("Failed to auto-start synchronization: %s", e)
|
|
|
|
|
|
|
|
def download_remote_state(token: Optional[str] = None):
    """Attempt to download `STATE_FILE` from the target repo and overwrite local state.

    If no remote state exists (non-200 response) this function returns
    quietly. Any other exceptions are logged and re-raised for the caller
    to handle.

    Args:
        token: Optional Hugging Face token used to authorize the request
            (required when the target repo is private).
    """
    try:
        download_url = hf_hub_url(
            repo_id=TARGET_REPO_ID,
            filename=STATE_FILE,
            repo_type=REPO_TYPE,
            revision=REVISION,
        )
        # Bug fix: the token parameter was previously accepted but ignored,
        # so a private target repo always failed with 401. Also add a timeout
        # so startup cannot hang indefinitely on a stuck connection.
        headers = {"Authorization": f"Bearer {token}"} if token else {}
        resp = requests.get(download_url, headers=headers, timeout=30)
        if resp.status_code == 200:
            with open(STATE_FILE, "wb") as f:
                f.write(resp.content)
            logger.info("Remote state downloaded from %s/%s", TARGET_REPO_ID, STATE_FILE)
        else:
            logger.info("Remote state not found (status %s) — continuing with local state if present", resp.status_code)
    except Exception:
        logger.exception("Failed to download remote state file")
        raise
|
|
|
|
|
|
|
|
class SyncStatus(BaseModel):
    """Response model for the /status endpoint."""

    # Overall sync state, e.g. "idle", "running", "completed", or "error".
    status: str
    # Number of files listed in the source repo.
    total_files: int
    # Count of files successfully copied to the target repo.
    synced_files: int
    # Count of files that raised an error during download/upload.
    failed_files: int
    # File currently being transferred, or None when no transfer is active.
    current_file: Optional[str] = None
    # Completion percentage in the range 0.0-100.0.
    progress_percent: float
    # ISO-8601 timestamp of the last state save.
    last_updated: str
|
|
|
|
|
class SyncRequest(BaseModel):
    """Request body for /sync; `token` overrides the HF_TOKEN env var when set."""

    token: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
|
def load_state():
    """Load the synchronization state from `STATE_FILE`.

    Returns:
        dict: The parsed state if the file exists and is valid JSON;
        otherwise a fresh default state.

    A corrupt or unreadable state file is treated the same as a missing one
    (previously json.load would raise and crash every endpoint that reads
    state); the error is logged and a default state is returned so the
    service can recover by restarting the sync.
    """
    if os.path.exists(STATE_FILE):
        try:
            with open(STATE_FILE, "r") as f:
                return json.load(f)
        except (json.JSONDecodeError, OSError):
            logger.exception("State file %s is unreadable or corrupt; falling back to a fresh state", STATE_FILE)
    return {
        "synced_files": [],
        "failed_files": [],
        "total_files": 0,
        "status": "idle",
        "last_updated": datetime.utcnow().isoformat(),
        "current_file": None
    }
|
|
|
|
|
|
|
|
|
|
|
def save_state(state, api: Optional[HfApi] = None, token: Optional[str] = None):
    """Persist *state* to the local `STATE_FILE` and mirror it to the target repo.

    The local write always happens first and is the source of truth. For the
    remote mirror, an explicit `api` takes precedence; failing that, a
    temporary HfApi is built from `token`. With neither, only the local file
    is written. Upload problems are logged and swallowed on purpose — saving
    state must never fail just because the remote push did.
    """
    state["last_updated"] = datetime.utcnow().isoformat()
    with open(STATE_FILE, "w") as f:
        json.dump(state, f, indent=4)

    try:
        # Resolve which client (if any) should push the state upstream.
        uploader = api
        if uploader is None and token:
            uploader = HfApi(token=token)
        if uploader is None:
            return

        try:
            uploader.upload_file(
                path_or_fileobj=STATE_FILE,
                path_in_repo=STATE_FILE,
                repo_id=TARGET_REPO_ID,
                repo_type=REPO_TYPE,
                commit_message=f"Sync: Update {STATE_FILE}",
            )
            logger.info("Uploaded state file to %s/%s", TARGET_REPO_ID, STATE_FILE)
        except Exception:
            # Best-effort mirror: keep going with the local copy.
            logger.exception("Failed to upload state file to remote repo")
    except Exception:
        logger.exception("Unexpected error while attempting to save/upload state")
|
|
|
|
|
|
|
|
|
|
|
def _download_and_upload_file(api: HfApi, file_path: str):
    """Download a file from the source repo and upload it to the target repo.

    Synchronous by design: run it via asyncio.to_thread so the blocking HTTP
    transfer does not stall the event loop.

    Args:
        api: Authenticated HfApi used for the upload.
        file_path: Path of the file inside the source repo (also used as the
            destination path in the target repo).

    Raises:
        requests.HTTPError: if the download request fails.
    """
    download_url = hf_hub_url(
        repo_id=SOURCE_REPO_ID,
        filename=file_path,
        repo_type=REPO_TYPE,
        revision=REVISION
    )
    local_path = os.path.basename(file_path)

    # Stream to disk in chunks so large files never have to fit in memory.
    # Fixes: `with` releases the pooled connection (the streamed response was
    # previously never closed), and the timeout keeps a stalled transfer from
    # hanging the worker thread forever (10s connect, 300s between reads).
    with requests.get(download_url, stream=True, timeout=(10, 300)) as response:
        response.raise_for_status()
        with open(local_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)

    try:
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=file_path,
            repo_id=TARGET_REPO_ID,
            repo_type=REPO_TYPE,
            commit_message=f"Sync: Add {file_path} from {SOURCE_REPO_ID}",
        )
    finally:
        # Fix: remove the temp copy even when the upload raises, so failed
        # files don't accumulate on disk across retries.
        try:
            os.remove(local_path)
        except Exception:
            logger.exception("Failed to remove local file %s", local_path)
|
|
|
|
|
|
|
|
|
|
|
async def synchronize_datasets(token: str):
    """
    Fetch the file list from the source repo, then download and re-upload each
    file to the target repo, persisting state after every step so an
    interrupted sync can resume where it left off.

    Args:
        token: Hugging Face token used for API access and uploads.
    """
    state = load_state()

    try:
        api = HfApi(token=token)
    except Exception as e:
        state["status"] = "error"
        state["error"] = f"Error initializing HfApi: {str(e)}"
        save_state(state)
        return

    state["status"] = "running"
    save_state(state, api=api)

    # Fix: use .get for synced_files too (failed_files already did) so an
    # older or hand-edited state file without the key doesn't raise KeyError.
    synced_files = set(state.get("synced_files", []))
    failed_files = set(state.get("failed_files", []))

    try:
        repo_files = api.list_repo_files(
            repo_id=SOURCE_REPO_ID,
            repo_type=REPO_TYPE,
            revision=REVISION
        )
    except Exception as e:
        state["status"] = "error"
        state["error"] = f"Error fetching file list: {str(e)}"
        save_state(state)
        return

    state["total_files"] = len(repo_files)
    # Skip files already copied and files that previously failed (clearing
    # failures requires /reset).
    files_to_sync = [f for f in repo_files if f not in synced_files and f not in failed_files]

    for idx, file_path in enumerate(files_to_sync):
        # Publish progress before starting the transfer so /status shows the
        # in-flight file.
        state["current_file"] = file_path
        state["synced_files_count"] = len(synced_files)
        state["progress_percent"] = (len(synced_files) / state["total_files"]) * 100 if state["total_files"] > 0 else 0
        save_state(state, api=api)

        # .gitattributes is repo metadata; mark it synced without copying.
        if file_path == ".gitattributes":
            synced_files.add(file_path)
            continue

        try:
            # Run the blocking download/upload off the event loop.
            await asyncio.to_thread(_download_and_upload_file, api, file_path)

            synced_files.add(file_path)
            state["synced_files"] = list(synced_files)
            save_state(state, api=api)
        except Exception as e:
            logger.exception("Error syncing file %s: %s", file_path, e)
            failed_files.add(file_path)
            state["failed_files"] = list(failed_files)
            save_state(state, api=api)

        # Throttle between files to stay within upstream rate limits; skip the
        # pause after the final file.
        if idx != len(files_to_sync) - 1:
            state["status"] = "running"
            save_state(state, api=api)
            # Bug fix: the log previously claimed "120 seconds" while the code
            # slept 140; derive both from one value so they can't drift again.
            delay_seconds = 140
            logger.info("Waiting %s seconds before processing next file", delay_seconds)
            await asyncio.sleep(delay_seconds)

    state["status"] = "completed"
    state["current_file"] = None
    state["synced_files_count"] = len(synced_files)
    state["progress_percent"] = 100.0
    save_state(state, api=api)
|
|
|
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """Liveness probe: confirm the service is up and identify it."""
    payload = {"status": "ok", "service": "HF Dataset Sync Service"}
    return payload
|
|
|
|
|
@app.get("/status", response_model=SyncStatus)
async def get_status():
    """Return a structured summary of the current synchronization progress."""
    snapshot = load_state()
    synced = snapshot.get("synced_files", [])
    failed = snapshot.get("failed_files", [])
    return SyncStatus(
        status=snapshot.get("status", "idle"),
        total_files=snapshot.get("total_files", 0),
        synced_files=len(synced),
        failed_files=len(failed),
        current_file=snapshot.get("current_file"),
        progress_percent=snapshot.get("progress_percent", 0.0),
        last_updated=snapshot.get("last_updated", datetime.utcnow().isoformat()),
    )
|
|
|
|
|
@app.post("/sync")
async def start_sync(request: SyncRequest, background_tasks: BackgroundTasks):
    """Kick off a background synchronization unless one is already running."""
    current = load_state()

    # Reject concurrent runs: two syncs would race on the state file.
    if current.get("status") == "running":
        raise HTTPException(status_code=409, detail="Sync is already running")

    # Prefer the caller-supplied token, falling back to the env-var token.
    effective_token = request.token if request.token else TOKEN
    background_tasks.add_task(synchronize_datasets, effective_token)

    return {"message": "Sync started", "status": "running"}
|
|
|
|
|
@app.post("/reset")
async def reset_state():
    """Discard all recorded progress and return the sync state to defaults."""
    fresh = {
        "synced_files": [],
        "failed_files": [],
        "total_files": 0,
        "status": "idle",
        "last_updated": datetime.utcnow().isoformat(),
        "current_file": None,
    }
    save_state(fresh)
    return {"message": "State reset", "status": "idle"}
|
|
|
|
|
@app.get("/state")
async def get_full_state():
    """Expose the raw persisted synchronization state as JSON."""
    state_snapshot = load_state()
    return state_snapshot
|
|
|
|
|
if __name__ == "__main__":
    # Local/dev entry point; in production the app is typically served by an
    # external ASGI server (e.g. `uvicorn module:app`) instead.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|