# app.py — HF Dataset Sync Service (samfred2's Hugging Face Space, commit d8853ef)
import os
import json
import asyncio
import logging
import requests
from datetime import datetime
from typing import Optional
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from huggingface_hub import HfApi, hf_hub_url
from tqdm import tqdm
# --- Configuration ---
SOURCE_REPO_ID = "Fred808/TGFiles"   # dataset repo files are copied FROM
TARGET_REPO_ID = "samfred2/TGFiles"  # dataset repo files are copied TO
REPO_TYPE = "dataset"
REVISION = "main"
STATE_FILE = "sync_state.json"       # local and in-repo name of the persisted progress file
TOKEN = os.environ.get("HF_TOKEN", "")
# Whether the service should automatically start syncing on FastAPI startup.
# Set AUTO_START env var to "0", "false" or "no" to disable.
AUTO_START = os.environ.get("AUTO_START", "true").lower() in ("1", "true", "yes")
# How long (seconds) before a remote "running" state is considered stale and ignored.
# Default: 1 hour. Set STATE_STALE_SECONDS env var to override.
try:
    STATE_STALE_SECONDS = int(os.environ.get("STATE_STALE_SECONDS", "3600"))
except ValueError:
    # os.environ.get always yields a str here, so only a non-numeric override
    # can fail; fall back to the 1-hour default instead of crashing at import.
    STATE_STALE_SECONDS = 3600
# --- FastAPI App ---
app = FastAPI(title="HF Dataset Sync Service")
# Configure basic logging
logging.basicConfig(level=logging.INFO)
# Module-level logger used by all handlers and helpers below.
logger = logging.getLogger(__name__)
# --- Auto-start handler ---
@app.on_event("startup")
async def _maybe_start_sync_on_startup():
    """If AUTO_START is enabled, launch synchronization in the background on app startup.

    A persisted "running" status is treated as stale — and reset to "idle" —
    when its last_updated timestamp is older than STATE_STALE_SECONDS or
    cannot be parsed; a recent "running" status is assumed to belong to a
    dead process and is simply resumed. Either way, a background sync task
    is started (the previous docstring claimed otherwise; this matches the
    actual control flow).
    """
    try:
        if not AUTO_START:
            logger.info("AUTO_START disabled; not starting sync on startup")
            return
        # Pull the remote copy of the state file first so we resume from the
        # last-known progress rather than whatever is (or isn't) on local disk.
        try:
            download_remote_state(TOKEN)
        except Exception:
            logger.exception("Failed to download remote state on startup; continuing with local state if present")
        state = load_state()
        if state.get("status") == "running":
            # "running" may be left over from a crashed or stopped instance.
            last = state.get("last_updated")
            stale = False
            if last:
                try:
                    last_dt = datetime.fromisoformat(last)
                    age = (datetime.utcnow() - last_dt).total_seconds()
                    if age > STATE_STALE_SECONDS:
                        stale = True
                        logger.info("Remote state marked running but last_updated is %.0f seconds old (>%s); treating as stale", age, STATE_STALE_SECONDS)
                except Exception:
                    # If we can't parse last_updated, treat as stale to be safe
                    stale = True
                    logger.exception("Failed to parse last_updated from remote state; treating running state as stale")
            if stale:
                state["status"] = "idle"
                state["current_file"] = None
                # Upload the corrected state back so other instances see it too.
                try:
                    save_state(state, token=TOKEN)
                except Exception:
                    logger.exception("Failed to upload corrected stale state; continuing with local corrected state")
            else:
                # Recent "running" state: assume there is no other live process
                # and resume that sync in this instance.
                logger.info("Remote state indicates an in-progress sync; resuming from that state")
        # Launch the async synchronization as a background task without blocking startup.
        logger.info("AUTO_START enabled; launching dataset synchronization in background")
        asyncio.create_task(synchronize_datasets(TOKEN))
    except Exception as e:
        logger.exception("Failed to auto-start synchronization: %s", e)
def download_remote_state(token: Optional[str] = None):
    """Attempt to download `STATE_FILE` from the target repo and overwrite local state.

    Any non-200 response (missing file, 404, etc.) is logged and the function
    returns quietly so the caller can fall back to local state; network/IO
    exceptions are logged and re-raised for the caller to handle.

    Args:
        token: Currently unused (the state file is fetched via the public
            resolve URL); kept so the call signature stays stable.
    """
    try:
        download_url = hf_hub_url(
            repo_id=TARGET_REPO_ID,
            filename=STATE_FILE,
            repo_type=REPO_TYPE,
            revision=REVISION,
        )
        # Bounded timeout so a hung request cannot stall startup indefinitely.
        resp = requests.get(download_url, timeout=30)
        if resp.status_code == 200:
            with open(STATE_FILE, "wb") as f:
                f.write(resp.content)
            logger.info("Remote state downloaded from %s/%s", TARGET_REPO_ID, STATE_FILE)
        else:
            logger.info("Remote state not found (status %s) — continuing with local state if present", resp.status_code)
    except Exception:
        logger.exception("Failed to download remote state file")
        raise
# --- Models ---
class SyncStatus(BaseModel):
    """Response model for the GET /status endpoint."""
    status: str  # "idle" | "running" | "completed" | "error" (values written by the sync loop)
    total_files: int  # number of files listed in the source repo
    synced_files: int  # count of successfully copied files
    failed_files: int  # count of files whose copy attempt failed
    current_file: Optional[str] = None  # file currently being processed, if any
    progress_percent: float  # 0-100 completion estimate
    last_updated: str  # ISO-8601 timestamp of the last state write
class SyncRequest(BaseModel):
    """Request body for POST /sync."""
    token: Optional[str] = None  # HF token override; falls back to the HF_TOKEN env var when absent
# --- State Management ---
def load_state():
    """Load the synchronization state from STATE_FILE.

    Returns:
        The parsed state dict, or a fresh default state when the file is
        missing, unreadable, or contains invalid JSON — a corrupt state file
        must not take every endpoint down with a parse error.
    """
    if os.path.exists(STATE_FILE):
        try:
            with open(STATE_FILE, "r") as f:
                return json.load(f)
        except (OSError, json.JSONDecodeError):
            logger.exception("State file %s is unreadable or corrupt; starting from a fresh state", STATE_FILE)
    return {
        "synced_files": [],
        "failed_files": [],
        "total_files": 0,
        "status": "idle",
        "last_updated": datetime.utcnow().isoformat(),
        "current_file": None
    }
def save_state(state, api: Optional[HfApi] = None, token: Optional[str] = None):
    """Persist *state* to the local STATE_FILE and best-effort mirror it remotely.

    The `last_updated` timestamp is refreshed before writing. When `api` is
    provided it is used to upload STATE_FILE to TARGET_REPO_ID; otherwise, if
    a `token` is supplied, a throwaway HfApi client is built for the upload.
    Upload failures are logged and swallowed — remote persistence is strictly
    best-effort and must never crash the service.
    """
    state["last_updated"] = datetime.utcnow().isoformat()
    with open(STATE_FILE, "w") as f:
        json.dump(state, f, indent=4)
    try:
        # Pick an upload client: the caller's, or a fresh one from the token.
        client = api
        if client is None and token:
            client = HfApi(token=token)
        if client is None:
            return  # nothing to upload with; local write already succeeded
        try:
            client.upload_file(
                path_or_fileobj=STATE_FILE,
                path_in_repo=STATE_FILE,
                repo_id=TARGET_REPO_ID,
                repo_type=REPO_TYPE,
                commit_message=f"Sync: Update {STATE_FILE}",
            )
            logger.info("Uploaded state file to %s/%s", TARGET_REPO_ID, STATE_FILE)
        except Exception:
            # Log and continue — failing to upload state should not crash the service
            logger.exception("Failed to upload state file to remote repo")
    except Exception:
        # Catch-all for unexpected errors constructing HfApi
        logger.exception("Unexpected error while attempting to save/upload state")
# Helper to perform blocking download/upload in a thread
def _download_and_upload_file(api: HfApi, file_path: str):
    """Download one file from the source repo and upload it to the target repo.

    Synchronous by design: callers run it via asyncio.to_thread so the
    blocking network I/O does not stall the event loop.

    Args:
        api: Authenticated HfApi client used for the upload.
        file_path: Path of the file inside the source repo.

    Raises:
        requests.HTTPError: if the download returns an error status.
        Exception: whatever huggingface_hub raises on upload failure; in that
            case the downloaded local copy is left on disk for inspection.
    """
    download_url = hf_hub_url(
        repo_id=SOURCE_REPO_ID,
        filename=file_path,
        repo_type=REPO_TYPE,
        revision=REVISION
    )
    # NOTE(review): basename only — two repo files sharing a basename would
    # collide if processed concurrently (they are not today: one file at a time).
    local_path = os.path.basename(file_path)
    # Stream to disk in chunks; the (connect, read) timeout ensures a dead
    # connection cannot hang the worker thread forever, and the context
    # manager releases the connection even if the write fails.
    with requests.get(download_url, stream=True, timeout=(10, 300)) as response:
        response.raise_for_status()
        with open(local_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
    # Upload file to target repo
    api.upload_file(
        path_or_fileobj=local_path,
        path_in_repo=file_path,
        repo_id=TARGET_REPO_ID,
        repo_type=REPO_TYPE,
        commit_message=f"Sync: Add {file_path} from {SOURCE_REPO_ID}",
    )
    # Clean up local file
    try:
        os.remove(local_path)
    except OSError:
        # If cleanup fails, just continue; file may be left behind for inspection
        logger.exception("Failed to remove local file %s", local_path)
# --- Synchronization Logic ---
async def synchronize_datasets(token: str):
    """
    Fetch the file list from the source repo, then download/upload each file
    to the target repo one at a time, persisting state (locally and remotely)
    so an interrupted sync can resume where it left off.

    Args:
        token: HF token used to build the HfApi client for uploads.
    """
    state = load_state()
    # Initialize HF API first so we can persist state remotely as we progress.
    try:
        api = HfApi(token=token)
    except Exception as e:
        state["status"] = "error"
        state["error"] = f"Error initializing HfApi: {str(e)}"
        save_state(state)
        return
    state["status"] = "running"
    save_state(state, api=api)
    # .get() so an older/partial state file without these keys doesn't KeyError.
    synced_files = set(state.get("synced_files", []))
    failed_files = set(state.get("failed_files", []))
    try:
        repo_files = api.list_repo_files(
            repo_id=SOURCE_REPO_ID,
            repo_type=REPO_TYPE,
            revision=REVISION
        )
    except Exception as e:
        state["status"] = "error"
        state["error"] = f"Error fetching file list: {str(e)}"
        save_state(state)
        return
    state["total_files"] = len(repo_files)
    # Previously-failed files are also skipped here; use /reset to retry them.
    files_to_sync = [f for f in repo_files if f not in synced_files and f not in failed_files]
    for idx, file_path in enumerate(files_to_sync):
        state["current_file"] = file_path
        state["synced_files_count"] = len(synced_files)
        state["progress_percent"] = (len(synced_files) / state["total_files"]) * 100 if state["total_files"] > 0 else 0
        save_state(state, api=api)
        # Skip .gitattributes (repo metadata) but record it as synced so it is
        # not reconsidered on resume.
        if file_path == ".gitattributes":
            synced_files.add(file_path)
            state["synced_files"] = list(synced_files)
            continue
        try:
            # Perform blocking download+upload in a thread to avoid blocking the event loop
            await asyncio.to_thread(_download_and_upload_file, api, file_path)
            # Mark as synced and persist (locally + remote)
            synced_files.add(file_path)
            state["synced_files"] = list(synced_files)
            save_state(state, api=api)
        except Exception as e:
            logger.exception("Error syncing file %s: %s", file_path, e)
            failed_files.add(file_path)
            state["failed_files"] = list(failed_files)
            save_state(state, api=api)
        # Throttle: wait 2 minutes between files; skip the wait after the last
        # one. (Sleep now matches the logged duration — it previously slept 140s
        # while logging 120s.)
        if idx != len(files_to_sync) - 1:
            state["status"] = "running"
            save_state(state, api=api)
            logger.info("Waiting 120 seconds before processing next file")
            await asyncio.sleep(120)
    state["status"] = "completed"
    state["current_file"] = None
    state["synced_files_count"] = len(synced_files)
    state["progress_percent"] = 100.0
    save_state(state, api=api)
# --- Endpoints ---
@app.get("/")
async def root():
    """Liveness probe: confirm the service is up."""
    payload = {"status": "ok", "service": "HF Dataset Sync Service"}
    return payload
@app.get("/status", response_model=SyncStatus)
async def get_status():
    """Summarize the persisted synchronization state as a SyncStatus payload."""
    snapshot = load_state()
    return SyncStatus(
        status=snapshot.get("status", "idle"),
        total_files=snapshot.get("total_files", 0),
        synced_files=len(snapshot.get("synced_files", [])),
        failed_files=len(snapshot.get("failed_files", [])),
        current_file=snapshot.get("current_file"),
        progress_percent=snapshot.get("progress_percent", 0.0),
        last_updated=snapshot.get("last_updated", datetime.utcnow().isoformat()),
    )
@app.post("/sync")
async def start_sync(request: SyncRequest, background_tasks: BackgroundTasks):
    """Kick off a background synchronization unless one is already running."""
    if load_state().get("status") == "running":
        raise HTTPException(status_code=409, detail="Sync is already running")
    # Prefer the token from the request body, falling back to the env token.
    auth_token = request.token or TOKEN
    background_tasks.add_task(synchronize_datasets, auth_token)
    return {"message": "Sync started", "status": "running"}
@app.post("/reset")
async def reset_state():
    """Wipe all sync progress back to a pristine idle state."""
    fresh = {
        "synced_files": [],
        "failed_files": [],
        "total_files": 0,
        "status": "idle",
        "last_updated": datetime.utcnow().isoformat(),
        "current_file": None,
    }
    save_state(fresh)
    return {"message": "State reset", "status": "idle"}
@app.get("/state")
async def get_full_state():
    """Expose the raw persisted state dict (useful for debugging)."""
    current = load_state()
    return current
if __name__ == "__main__":
    import uvicorn
    # Default to port 8000, but let the hosting platform override it via the
    # PORT environment variable (backward-compatible generalization).
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))