# dev
# Update app.py: Windows compatibility, parent dir auto-creation, safe delete
# d8f61ea
#!/usr/bin/env python3
"""WebDAV server backed by Hugging Face Hub (dataset storage).
Windows-compatible: includes all required DAV headers, proper PROPFIND
responses, parent-directory auto-creation, and safe delete logic.
"""
import os
import io
import logging
import stat
from datetime import datetime, timezone
from typing import Optional
from wsgidav.wsgidav_app import WsgiDAVApp
from wsgidav.dav_provider import DAVProvider, DAVCollection, DAVNonCollection
from huggingface_hub import HfFileSystem, hf_hub_download
# Configure the root logger once at import time; everything below logs at INFO.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ── Configuration ──────────────────────────────────────────────
# Target repository: taken from the HF_REPO_ID environment variable, with a
# hard-coded fallback when the variable is not set.
HF_REPO_ID = os.environ.get("HF_REPO_ID", "e2dew32/cloud-drive")
HF_REPO_TYPE = "dataset"
# Single filesystem client shared by every helper and resource class below.
fs = HfFileSystem()
# Path prefix for everything served: "<repo type>s/<repo id>",
# i.e. "datasets/<repo id>" with the repo type above.
REPO_ROOT = f"{HF_REPO_TYPE}s/{HF_REPO_ID}"
def _norm(path: str) -> str:
    """Map a WebDAV request path onto the matching path inside the repo.

    Leading and trailing slashes are dropped; an empty result (the WebDAV
    root) maps to ``REPO_ROOT`` itself, anything else is appended to it.
    """
    relative = path.strip("/")
    if not relative:
        return REPO_ROOT
    return f"{REPO_ROOT}/{relative}"
def _stat(path: str) -> dict:
    """Look up *path* with ``fs.info`` and return the resulting info dict.

    Any exception raised by the lookup is swallowed and reported as an
    empty dict, so callers can treat a falsy result as "not available".
    """
    try:
        return fs.info(path)
    except Exception:
        return {}
def _ensure_parent(hf_path: str):
    """Best-effort creation of the directory that will contain *hf_path*.

    The parent is everything before the last ``/``. Nothing happens when
    that parent is empty or is ``REPO_ROOT``. Failures are logged as
    warnings and never raised.
    """
    parent, _, _ = hf_path.rpartition("/")
    # Top-level entries need no work: the repo root is never created here.
    if parent in ("", REPO_ROOT):
        return
    try:
        already_present = fs.isdir(parent)
        if already_present:
            return
        fs.mkdir(parent, exist_ok=True, parents=True)
        logger.info(f"Created parent dir: {parent}")
    except Exception as e:
        logger.warning(f"ensure_parent failed for {parent}: {e}")
# ── DAV Provider ───────────────────────────────────────────────
class HFDAVProvider(DAVProvider):
    """Expose HF Hub storage as a WebDAV tree.

    Every lookup is translated to a repository path with ``_norm`` and
    resolved through ``_stat``; the provider itself keeps no state.
    """

    def __init__(self):
        super().__init__()

    def get_resource_inst(self, path, environ):
        """Return the DAV resource for *path*.

        Returns:
            ``None`` when ``_stat`` yields nothing (missing path or failed
            lookup), an ``HFDavCollection`` when the entry's ``type`` is
            ``"directory"``, otherwise an ``HFDavNonCollection``.
        """
        info = _stat(_norm(path))
        if not info:
            return None
        if info.get("type") == "directory":
            return HFDavCollection(path, environ, info)
        return HFDavNonCollection(path, environ, info)

    def is_collection(self, path, environ=None):
        """Return True when *path* resolves to a directory.

        *environ* is accepted and ignored so the method works when invoked
        as ``is_collection(path)`` and as ``is_collection(path, environ)``.
        NOTE(review): the two-argument form is understood to be the
        signature of the WsgiDAV base class; confirm against the installed
        version.
        """
        return _stat(_norm(path)).get("type") == "directory"
class HFDavCollection(DAVCollection):
    """WebDAV collection backed by a directory in the HF Hub repository.

    All storage access goes through the module-level ``fs`` client. Hub
    failures in listing, MKCOL and DELETE are logged and swallowed rather
    than raised.
    """

    def __init__(self, path, environ, info):
        super().__init__(path, environ)
        # Info dict obtained by the provider; no method of this class reads it.
        self._info = info

    def get_display_name(self):
        """Return the last path segment, or "/" for the root collection."""
        return self.path.rstrip("/").split("/")[-1] or "/"

    def get_member_names(self):
        """Return the names of the direct children of this directory.

        A failed listing is reported as an empty list.
        """
        norm = _norm(self.path)
        try:
            entries = fs.ls(norm, detail=False)
        except Exception:
            return []
        basenames = (entry.rsplit("/", 1)[-1] for entry in entries)
        # Filter out . and .. to prevent Windows infinite loops
        return [name for name in basenames if name not in (".", "..", "")]

    def get_member(self, name):
        """Resolve the child called *name* through the provider."""
        child_path = self.path.rstrip("/") + "/" + name
        return self.provider.get_resource_inst(child_path, self.environ)

    def create_empty_resource(self, name):
        """Return a zero-length file resource for a child that is about to be written.

        Nothing is uploaded here; only the parent directory is ensured.
        """
        child_path = self.path.rstrip("/") + "/" + name
        _ensure_parent(_norm(child_path))
        return HFDavNonCollection(
            child_path, self.environ, {"type": "file", "size": 0, "name": name}
        )

    def create_collection(self, name):
        """MKCOL – create a directory.

        A failure is logged as an error and not raised, so the caller
        cannot tell from this method whether the directory was created.
        """
        new_path = self.path.rstrip("/") + "/" + name
        norm = _norm(new_path)
        try:
            fs.mkdir(norm, exist_ok=True, parents=True)
            logger.info(f"MKCOL: {norm}")
        except Exception as e:
            logger.error(f"mkdir failed: {e}")

    def get_creation_date(self):
        """Return the current UTC time; no stored timestamp is consulted."""
        return datetime.now(timezone.utc)

    def get_modified_date(self):
        """Return the current UTC time; no stored timestamp is consulted."""
        return datetime.now(timezone.utc)

    def get_content_length(self):
        """Directories report a length of 0."""
        return 0

    def get_content_type(self):
        """Return the MIME type used for directories."""
        return "httpd/unix-directory"

    def is_collection(self):
        """Always True for this class."""
        return True

    def support_etag(self):
        """ETags are not provided (falsy result)."""
        return 0

    def get_etag(self):
        """No ETag is available."""
        return None

    def is_property_locked(self, name):
        """No property is ever reported as locked."""
        return False

    # NOTE(review): the two overrides below replace whatever the base class
    # would report for properties with "nothing"; confirm PROPFIND responses
    # still carry the properties clients rely on.
    def get_property_names(self, is_allprop):
        """Report no property names."""
        return []

    def get_property_value(self, name):
        """Report no value for any property."""
        return None

    def support_recursive_delete(self, path=None):
        """Declare that ``delete()`` may be called on a non-empty collection.

        *path* is optional and ignored: the original signature required it,
        so a call without arguments raised ``TypeError``. NOTE(review): the
        WsgiDAV hook is understood to be invoked with no arguments; confirm
        against the installed version.
        """
        return True

    def delete(self):
        """Remove this directory and everything beneath it.

        The repository root is never deleted. A failure on one child is
        logged and the remaining children are still processed; a failure of
        the listing or of the final ``rmdir`` is logged as an error.
        """
        norm = _norm(self.path)
        # Safety: never delete the repo root
        if norm == REPO_ROOT:
            logger.warning("Blocked attempt to delete repo root")
            return
        try:
            for e in fs.ls(norm, detail=False):
                try:
                    # Sub-directories are removed recursively, files plainly.
                    fs.rm(e, recursive=fs.isdir(e))
                except Exception as inner:
                    logger.warning(f"delete entry {e}: {inner}")
            fs.rmdir(norm)
            logger.info(f"Deleted collection: {norm}")
        except Exception as e:
            logger.error(f"delete collection failed: {e}")
class HFDavNonCollection(DAVNonCollection):
    """WebDAV file resource backed by a single file in the HF Hub repository.

    Reads, writes, deletes, moves and copies go through the module-level
    ``fs`` client. Failures are logged and swallowed rather than raised.
    """

    # Extension (lower-case, without the dot) -> MIME type. Built once for the
    # class instead of on every get_content_type() call.
    _MIME_TYPES = {
        "txt": "text/plain",
        "md": "text/markdown",
        "json": "application/json",
        "html": "text/html",
        "css": "text/css",
        "js": "application/javascript",
        "xml": "application/xml",
        "csv": "text/csv",
        "jpg": "image/jpeg",
        "jpeg": "image/jpeg",
        "png": "image/png",
        "gif": "image/gif",
        "svg": "image/svg+xml",
        "pdf": "application/pdf",
        "zip": "application/zip",
        "mp3": "audio/mpeg",
        "mp4": "video/mp4",
    }

    def __init__(self, path, environ, info):
        super().__init__(path, environ)
        # Info dict for this file; "size" and "last_commit" are read below.
        self._info = info

    def get_display_name(self):
        """Return the last segment of the resource path."""
        return self.path.rstrip("/").split("/")[-1]

    def get_content_length(self):
        """Return the ``size`` recorded in the info dict, or 0 if absent."""
        return self._info.get("size", 0)

    def get_content_type(self):
        """Guess the MIME type from the file extension.

        Names without a dot, and extensions missing from the table, map to
        ``application/octet-stream``.
        """
        _, dot, ext = self.get_display_name().rpartition(".")
        if not dot:
            return "application/octet-stream"
        return self._MIME_TYPES.get(ext.lower(), "application/octet-stream")

    def get_creation_date(self):
        """Return the current UTC time; no stored timestamp is consulted."""
        return datetime.now(timezone.utc)

    def get_modified_date(self):
        """Return the date of the last commit, falling back to the current UTC time.

        The ``last_commit`` entry may be a mapping or an object exposing a
        ``date`` attribute, and the date itself may be an ISO-8601 string
        (a trailing ``Z`` is accepted) or already a ``datetime``. Previously
        a ``datetime`` value was silently discarded and a non-mapping entry
        raised ``AttributeError``.
        """
        last_commit = self._info.get("last_commit") or {}
        if isinstance(last_commit, dict):
            stamp = last_commit.get("date")
        else:
            stamp = getattr(last_commit, "date", None)

        if isinstance(stamp, datetime):
            return stamp
        if isinstance(stamp, str) and stamp:
            try:
                return datetime.fromisoformat(stamp.replace("Z", "+00:00"))
            except ValueError:
                # Unparseable timestamp: fall through to the "now" fallback.
                logger.debug("Unparseable commit date %r for %s", stamp, self.path)
        return datetime.now(timezone.utc)

    def get_content(self):
        """Return the whole file as an in-memory binary stream.

        The file is read fully before returning. A failed read is logged
        and served as an empty stream.
        """
        norm = _norm(self.path)
        try:
            with fs.open(norm, "rb") as f:
                return io.BytesIO(f.read())
        except Exception as e:
            logger.error(f"read failed: {e}")
            return io.BytesIO(b"")

    def begin_write(self, content_type=None):
        """Return a buffer that uploads its content to this path when closed.

        *content_type* is accepted but not used.
        """
        norm = _norm(self.path)
        _ensure_parent(norm)
        logger.info(f"begin_write: {norm}")
        return HFWriteBuffer(norm)

    def is_property_locked(self, name):
        """No property is ever reported as locked."""
        return False

    # NOTE(review): the two overrides below replace whatever the base class
    # would report for properties with "nothing"; confirm PROPFIND responses
    # still carry the properties clients rely on.
    def get_property_names(self, is_allprop):
        """Report no property names."""
        return []

    def get_property_value(self, name):
        """Report no value for any property."""
        return None

    def delete(self):
        """Remove the file; a failure is logged and not raised."""
        norm = _norm(self.path)
        try:
            fs.rm(norm)
            logger.info(f"Deleted file: {norm}")
        except Exception as e:
            logger.error(f"delete file failed: {e}")

    def support_etag(self):
        """ETags are not provided (falsy result)."""
        return 0

    def get_etag(self):
        """No ETag is available."""
        return None

    # NOTE(review): confirm that the framework actually dispatches MOVE and
    # COPY to methods with these names and signatures.
    def move_dest(self, dest_provider, dest_path, recursive, dry_run, environ):
        """Handle MOVE (rename).

        The content is read, written to the destination, and only then is
        the source removed. *dest_provider*, *recursive*, *dry_run* and
        *environ* are not used. Failures are logged and not raised.
        """
        src_norm = _norm(self.path)
        dst_norm = _norm(dest_path)
        logger.info(f"MOVE {src_norm} -> {dst_norm}")
        try:
            _ensure_parent(dst_norm)
            data = fs.cat(src_norm)
            fs.pipe(dst_norm, data)
            fs.rm(src_norm)
        except Exception as e:
            logger.error(f"move failed: {e}")

    def copy(self, dest_provider, dest_path, environ, depth="infinity", dry_run=False):
        """Handle COPY.

        The content is read and written to the destination. *dest_provider*,
        *environ*, *depth* and *dry_run* are not used. Failures are logged
        and not raised.
        """
        src_norm = _norm(self.path)
        dst_norm = _norm(dest_path)
        logger.info(f"COPY {src_norm} -> {dst_norm}")
        try:
            _ensure_parent(dst_norm)
            data = fs.cat(src_norm)
            fs.pipe(dst_norm, data)
        except Exception as e:
            logger.error(f"copy failed: {e}")
class HFWriteBuffer(io.BytesIO):
    """In-memory buffer whose content is uploaded to HF Hub when it is closed.

    Args:
        path: destination path handed to ``fs.pipe`` on close.
    """

    def __init__(self, path):
        super().__init__()
        self.path = path

    def close(self):
        """Upload the buffered bytes once, then release the buffer.

        Calling ``close()`` again is a no-op. Previously a second call
        reached ``getvalue()`` on an already-closed buffer and raised
        ``ValueError``. An upload failure is logged and not raised.
        """
        if self.closed:
            return
        data = self.getvalue()
        try:
            fs.pipe(self.path, data)
            logger.info(f"Written {len(data)} bytes to {self.path}")
        except Exception as e:
            logger.error(f"Write failed: {e}")
        finally:
            # Always release the underlying buffer, whatever happened above.
            super().close()
# ── Windows Compatibility Middleware ───────────────────────────
class WindowsCompatMiddleware:
    """WSGI middleware that adds the headers the Windows WebClient expects.

    OPTIONS requests are answered directly, without reaching the wrapped
    application. Every other response is passed through with
    ``MS-Author-Via``, ``DAV`` and ``Access-Control-Allow-Origin`` added
    when the application did not set them itself.
    """

    # Added to each proxied response unless already present (case-insensitive).
    _COMPAT_HEADERS = (
        ("MS-Author-Via", "DAV"),
        ("DAV", "1, 2"),
        ("Access-Control-Allow-Origin", "*"),
    )
    _ALLOWED_METHODS = (
        "OPTIONS, GET, HEAD, POST, PUT, DELETE, PROPFIND, PROPPATCH, "
        "MKCOL, COPY, MOVE, LOCK, UNLOCK"
    )

    def __init__(self, app):
        self.app = app

    def __call__(self, environ, start_response):
        """Handle one WSGI request; see the class docstring for the rules."""
        if environ.get("REQUEST_METHOD") == "OPTIONS":
            # Windows requires these headers in OPTIONS response
            headers = [
                ("DAV", "1, 2"),
                ("Allow", self._ALLOWED_METHODS),
                ("Content-Length", "0"),
                ("MS-Author-Via", "DAV"),
                ("Content-Type", "httpd/unix-directory"),
                # Same CORS header every other response receives below.
                ("Access-Control-Allow-Origin", "*"),
            ]
            start_response("200 OK", headers)
            return [b""]

        def custom_start_response(status, response_headers, exc_info=None):
            # Work on a copy: the wrapped app's own sequence is left untouched
            # and may be any iterable of (name, value) pairs.
            headers = list(response_headers)
            present = {name.lower() for name, _ in headers}
            headers.extend(
                pair for pair in self._COMPAT_HEADERS if pair[0].lower() not in present
            )
            return start_response(status, headers, exc_info)

        return self.app(environ, custom_start_response)
# ── WsgiDAVApp setup ──────────────────────────────────────────
def create_app():
    """Build the WsgiDAV application and wrap it in ``WindowsCompatMiddleware``.

    The single account comes from the ``WEBDAV_USER`` / ``WEBDAV_PASS``
    environment variables, falling back to ``user`` / ``pass`` when unset.
    The whole tree ("/") is served by one ``HFDAVProvider``.
    """
    username = os.environ.get("WEBDAV_USER", "user")
    password = os.environ.get("WEBDAV_PASS", "pass")

    # One account with the "admin" role, registered under the "*" key
    # (presumably a match-all share; verify against the WsgiDAV docs).
    accounts = {username: {"password": password, "roles": ["admin"]}}

    config = {
        "host": "0.0.0.0",
        "port": 7860,
        "provider_mapping": {"/": HFDAVProvider()},
        "simple_dc": {"user_mapping": {"*": accounts}},
        "http_authenticator": {
            "DAV_auth_type": "basic",
        },
        "verbose": 1,
        # Disable browser UI, pure WebDAV
        "dir_browser": {"enable": False},
    }

    # Wrap with Windows compatibility middleware
    return WindowsCompatMiddleware(WsgiDAVApp(config))
if __name__ == "__main__":
    # Script entry point: serve the wrapped WSGI app with cheroot.
    from cheroot.wsgi import Server

    bind_address = ("0.0.0.0", 7860)
    httpd = Server(bind_address, create_app())
    logger.info("Starting WebDAV server on :7860")
    try:
        httpd.start()
    except KeyboardInterrupt:
        # Ctrl-C: stop the server instead of dying with a traceback.
        httpd.stop()