#!/usr/bin/env python3
"""WebDAV server backed by Hugging Face Hub (dataset storage).

Windows-compatible: includes all required DAV headers, proper PROPFIND
responses, parent-directory auto-creation, and safe delete logic.
"""

import os
import io
import logging
import stat
from datetime import datetime, timezone
from typing import Optional

from wsgidav.wsgidav_app import WsgiDAVApp
from wsgidav.dav_provider import DAVProvider, DAVCollection, DAVNonCollection

from huggingface_hub import HfFileSystem, hf_hub_download

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ── Configuration ──────────────────────────────────────────────
HF_REPO_ID = os.environ.get("HF_REPO_ID", "e2dew32/cloud-drive")
HF_REPO_TYPE = "dataset"

fs = HfFileSystem()
REPO_ROOT = f"{HF_REPO_TYPE}s/{HF_REPO_ID}"


def _norm(path: str) -> str:
    """Normalise a WebDAV path to an HF path."""
    path = path.strip("/")
    return f"{REPO_ROOT}/{path}" if path else REPO_ROOT


def _stat(path: str) -> dict:
    """Return the HfFileSystem info dict for *path*, or {} on any error.

    Missing paths (and transient Hub errors) are both reported as "not
    found" — callers treat an empty dict as a 404.
    """
    try:
        return fs.info(path)
    except Exception:
        return {}


def _now_ts() -> float:
    """Current time as UTC epoch seconds — the unit wsgidav date hooks expect."""
    return datetime.now(timezone.utc).timestamp()


def _ensure_parent(hf_path: str) -> None:
    """Create parent directories on HF Hub if they don't exist (best effort)."""
    parent = "/".join(hf_path.split("/")[:-1])
    if not parent or parent == REPO_ROOT:
        return
    try:
        if not fs.isdir(parent):
            # NOTE(review): fsspec's AbstractFileSystem.mkdir signature is
            # (path, create_parents=True, **kwargs); exist_ok/parents travel
            # through **kwargs — confirm HfFileSystem honours them.
            fs.mkdir(parent, exist_ok=True, parents=True)
            logger.info(f"Created parent dir: {parent}")
    except Exception as e:
        # Best-effort: a failed mkdir is logged, not fatal — the subsequent
        # write may still succeed (Hub directories are implicit).
        logger.warning(f"ensure_parent failed for {parent}: {e}")


# ── DAV Provider ───────────────────────────────────────────────
class HFDAVProvider(DAVProvider):
    """Expose HF Hub storage as a WebDAV tree."""

    def __init__(self):
        super().__init__()

    def get_resource_inst(self, path, environ):
        """Map a WebDAV path to a collection/file resource, or None if absent."""
        norm = _norm(path)
        info = _stat(norm)
        if not info:
            return None
        if info.get("type") == "directory":
            return HFDavCollection(path, environ, info)
        return HFDavNonCollection(path, environ, info)

    def is_collection(self, path):
        info = _stat(_norm(path))
        return info.get("type") == "directory"


class HFDavCollection(DAVCollection):
    """A directory in the HF repo, exposed as a WebDAV collection."""

    def __init__(self, path, environ, info):
        super().__init__(path, environ)
        self._info = info

    def get_display_name(self):
        return self.path.rstrip("/").split("/")[-1] or "/"

    def get_member_names(self):
        norm = _norm(self.path)
        try:
            entries = fs.ls(norm, detail=False)
        except Exception:
            return []
        names = []
        for e in entries:
            name = e.rsplit("/", 1)[-1]
            # Filter out . and .. to prevent Windows infinite loops
            if name not in (".", "..", ""):
                names.append(name)
        return names

    def get_member(self, name):
        child_path = self.path.rstrip("/") + "/" + name
        return self.provider.get_resource_inst(child_path, self.environ)

    def create_empty_resource(self, name):
        """Called when client does PUT with no body first.

        NOTE(review): this returns an in-memory placeholder without
        materialising a zero-byte file on the Hub; the real content arrives
        via begin_write(). Confirm Windows' PROPFIND-between-PUT flow copes.
        """
        child_path = self.path.rstrip("/") + "/" + name
        norm = _norm(child_path)
        _ensure_parent(norm)
        return HFDavNonCollection(
            child_path, self.environ, {"type": "file", "size": 0, "name": name}
        )

    def create_collection(self, name):
        """MKCOL – create a directory."""
        new_path = self.path.rstrip("/") + "/" + name
        norm = _norm(new_path)
        try:
            fs.mkdir(norm, exist_ok=True, parents=True)
            logger.info(f"MKCOL: {norm}")
        except Exception as e:
            logger.error(f"mkdir failed: {e}")

    # wsgidav expects date hooks to return UTC epoch seconds (floats), and
    # the modification hook is named get_last_modified. The original code
    # returned datetime objects from get_creation_date/get_modified_date,
    # which wsgidav cannot serialise into PROPFIND dates.
    def get_creation_date(self):
        """Creation time as epoch seconds (Hub does not track one)."""
        return _now_ts()

    def get_last_modified(self):
        """Last-modified time as epoch seconds (wsgidav's hook name)."""
        return _now_ts()

    def get_modified_date(self):
        """Legacy alias kept for backward compatibility."""
        return self.get_last_modified()

    def get_content_length(self):
        return 0

    def get_content_type(self):
        return "httpd/unix-directory"

    def is_collection(self):
        return True

    def support_etag(self):
        return False

    def get_etag(self):
        return None

    def is_property_locked(self, name):
        return False

    def get_property_names(self, is_allprop):
        return []

    def get_property_value(self, name):
        return None

    def support_recursive_delete(self, path):
        return True

    def delete(self):
        """DELETE on a directory: remove children, then the directory itself."""
        norm = _norm(self.path)
        # Safety: never delete the repo root
        if norm == REPO_ROOT:
            logger.warning("Blocked attempt to delete repo root")
            return
        try:
            entries = fs.ls(norm, detail=False)
            for e in entries:
                try:
                    if fs.isdir(e):
                        fs.rm(e, recursive=True)
                    else:
                        fs.rm(e)
                except Exception as inner:
                    # Keep going: one failed child should not abort the rest.
                    logger.warning(f"delete entry {e}: {inner}")
            fs.rmdir(norm)
            logger.info(f"Deleted collection: {norm}")
        except Exception as e:
            logger.error(f"delete collection failed: {e}")


class HFDavNonCollection(DAVNonCollection):
    """A file in the HF repo, exposed as a WebDAV resource."""

    def __init__(self, path, environ, info):
        super().__init__(path, environ)
        self._info = info

    def get_display_name(self):
        return self.path.rstrip("/").split("/")[-1]

    def get_content_length(self):
        return self._info.get("size", 0)

    def get_content_type(self):
        """Guess a MIME type from the extension; octet-stream if unknown."""
        name = self.get_display_name()
        if "." in name:
            ext = name.rsplit(".", 1)[-1].lower()
            types = {
                "txt": "text/plain",
                "md": "text/markdown",
                "json": "application/json",
                "html": "text/html",
                "css": "text/css",
                "js": "application/javascript",
                "xml": "application/xml",
                "csv": "text/csv",
                "jpg": "image/jpeg",
                "jpeg": "image/jpeg",
                "png": "image/png",
                "gif": "image/gif",
                "svg": "image/svg+xml",
                "pdf": "application/pdf",
                "zip": "application/zip",
                "mp3": "audio/mpeg",
                "mp4": "video/mp4",
            }
            return types.get(ext, "application/octet-stream")
        return "application/octet-stream"

    def get_creation_date(self):
        """Creation time as epoch seconds (Hub does not track one)."""
        return _now_ts()

    def get_last_modified(self):
        """Last-modified time as epoch seconds, from the file's last commit.

        HfFileSystem may report last_commit.date as either a datetime or an
        ISO-8601 string; both forms are handled (the original code assumed a
        string and silently lost datetime values via the broad except).
        """
        lc = self._info.get("last_commit") or {}
        mtime = lc.get("date")
        if mtime:
            try:
                if isinstance(mtime, datetime):
                    return mtime.timestamp()
                return datetime.fromisoformat(
                    str(mtime).replace("Z", "+00:00")
                ).timestamp()
            except Exception:
                pass
        return _now_ts()

    def get_modified_date(self):
        """Legacy alias kept for backward compatibility."""
        return self.get_last_modified()

    def get_content(self):
        """GET: return the file body as a seekable stream (empty on error)."""
        norm = _norm(self.path)
        try:
            with fs.open(norm, "rb") as f:
                return io.BytesIO(f.read())
        except Exception as e:
            logger.error(f"read failed: {e}")
            return io.BytesIO(b"")

    def begin_write(self, content_type=None):
        """PUT: return a buffer that uploads to the Hub on close()."""
        norm = _norm(self.path)
        _ensure_parent(norm)
        logger.info(f"begin_write: {norm}")
        return HFWriteBuffer(norm)

    def is_property_locked(self, name):
        return False

    def get_property_names(self, is_allprop):
        return []

    def get_property_value(self, name):
        return None

    def delete(self):
        norm = _norm(self.path)
        try:
            fs.rm(norm)
            logger.info(f"Deleted file: {norm}")
        except Exception as e:
            logger.error(f"delete file failed: {e}")

    def support_etag(self):
        return False

    def get_etag(self):
        return None

    # NOTE(review): wsgidav's move/copy hooks are copy_move_single() /
    # move_recursive(); confirm these custom-named methods are actually
    # invoked by the framework version in use.
    def move_dest(self, dest_provider, dest_path, recursive, dry_run, environ):
        """Handle MOVE (rename) as read + write + delete-source."""
        src_norm = _norm(self.path)
        dst_norm = _norm(dest_path)
        logger.info(f"MOVE {src_norm} -> {dst_norm}")
        try:
            _ensure_parent(dst_norm)
            data = fs.cat(src_norm)
            fs.pipe(dst_norm, data)
            fs.rm(src_norm)
        except Exception as e:
            logger.error(f"move failed: {e}")

    def copy(self, dest_provider, dest_path, environ, depth="infinity", dry_run=False):
        """Handle COPY as read + write."""
        src_norm = _norm(self.path)
        dst_norm = _norm(dest_path)
        logger.info(f"COPY {src_norm} -> {dst_norm}")
        try:
            _ensure_parent(dst_norm)
            data = fs.cat(src_norm)
            fs.pipe(dst_norm, data)
        except Exception as e:
            logger.error(f"copy failed: {e}")


class HFWriteBuffer(io.BytesIO):
    """Buffer for writing file content to HF Hub.

    The upload happens exactly once, when the buffer is first closed.
    """

    def __init__(self, path):
        super().__init__()
        self.path = path

    def close(self):
        # Idempotence guard: close() can be called more than once (explicit
        # close + GC / context-manager exit); getvalue() on a closed BytesIO
        # raises ValueError, so only the first close uploads.
        if self.closed:
            return
        data = self.getvalue()
        try:
            fs.pipe(self.path, data)
            logger.info(f"Written {len(data)} bytes to {self.path}")
        except Exception as e:
            logger.error(f"Write failed: {e}")
        super().close()


# ── Windows Compatibility Middleware ───────────────────────────
class WindowsCompatMiddleware:
    """Add headers required by Windows WebClient and fix OPTIONS response."""

    def __init__(self, app):
        self.app = app

    def __call__(self, environ, start_response):
        if environ.get("REQUEST_METHOD") == "OPTIONS":
            # Windows requires these headers in OPTIONS response
            headers = [
                ("DAV", "1, 2"),
                ("Allow", "OPTIONS, GET, HEAD, POST, PUT, DELETE, PROPFIND, PROPPATCH, MKCOL, COPY, MOVE, LOCK, UNLOCK"),
                ("Content-Length", "0"),
                ("MS-Author-Via", "DAV"),
                ("Content-Type", "httpd/unix-directory"),
            ]
            start_response("200 OK", headers)
            return [b""]

        def custom_start_response(status, response_headers, exc_info=None):
            # Ensure every response has Windows-required headers
            header_names = {k.lower() for k, _ in response_headers}
            if "ms-author-via" not in header_names:
                response_headers.append(("MS-Author-Via", "DAV"))
            if "dav" not in header_names:
                response_headers.append(("DAV", "1, 2"))
            if "access-control-allow-origin" not in header_names:
                response_headers.append(("Access-Control-Allow-Origin", "*"))
            return start_response(status, response_headers, exc_info)

        return self.app(environ, custom_start_response)


# ── WsgiDAVApp setup ──────────────────────────────────────────
def create_app():
    """Build the WsgiDAVApp wrapped in the Windows-compatibility middleware.

    Credentials come from WEBDAV_USER / WEBDAV_PASS (defaults: user/pass).
    """
    user = os.environ.get("WEBDAV_USER", "user")
    passwd = os.environ.get("WEBDAV_PASS", "pass")
    app = WsgiDAVApp(
        {
            "host": "0.0.0.0",
            "port": 7860,
            "provider_mapping": {"/": HFDAVProvider()},
            "simple_dc": {
                "user_mapping": {"*": {user: {"password": passwd, "roles": ["admin"]}}}
            },
            # NOTE(review): wsgidav's documented basic-auth switches are
            # accept_basic / accept_digest; verify "DAV_auth_type" is
            # recognised by the installed wsgidav version.
            "http_authenticator": {
                "DAV_auth_type": "basic",
            },
            "verbose": 1,
            "dir_browser": {"enable": False},  # Disable browser UI, pure WebDAV
        }
    )
    # Wrap with Windows compatibility middleware
    return WindowsCompatMiddleware(app)


if __name__ == "__main__":
    from cheroot.wsgi import Server

    app = create_app()
    server = Server(("0.0.0.0", 7860), app)
    logger.info("Starting WebDAV server on :7860")
    try:
        server.start()
    except KeyboardInterrupt:
        server.stop()