proxy / app.py
Dev
Proxy v2
7a086c0
Raw
History Blame Contribute Delete
10.4 kB
import hashlib
import http.server
import json
import os
import re
import shutil
import threading
import time
import urllib.request
from datetime import datetime, timedelta
from pathlib import Path
from urllib.parse import parse_qs, quote, unquote, urlparse
# Access token read from the environment (empty string when unset); it is
# placed in the Authorization header of outgoing download requests.
HF_TOKEN = os.getenv("HF_TOKEN", "")

# Directory holding cached payloads and their ".meta" sidecars; created at
# import time if it does not exist yet.
CACHE_DIR = Path("/tmp/proxy_cache")
CACHE_DIR.mkdir(exist_ok=True)

# Entry lifetime in seconds (5 hours).
CACHE_TTL = 5 * 60 * 60

# Size threshold in bytes (2 GiB) above which cleanup starts evicting.
MAX_CACHE_BYTES = 2 * 1024 ** 3

# Module-wide lock guarding cache-directory mutations.
LOCK = threading.Lock()
class SmartCache:
    """Disk-backed file cache stored flat inside ``CACHE_DIR``.

    Each entry is a payload file plus a ``<name>.meta`` JSON sidecar holding
    ``ts``, ``expires``, ``hits``, ``size`` and ``content_type``.  Entries
    live for ``CACHE_TTL`` seconds; once the payload total reaches
    ``MAX_CACHE_BYTES`` the least-hit / oldest entries are evicted until the
    total drops to half of that limit.

    Keys are percent-encoded into a single path component before touching
    the file system, so keys containing ``/`` (for example
    ``dataset/file.mp4``) work and no key can escape ``CACHE_DIR``.  The
    encoding is idempotent: the raw key and its encoded form address the
    same entry.
    """

    def __init__(self):
        """Create the cache and start its background cleaner thread."""
        self._start_cleaner()

    @staticmethod
    def _safe_name(key):
        """Return the on-disk file name for *key*, or None if it is unusable."""
        if not isinstance(key, str):
            return None
        # unquote() first so an already-encoded key is not encoded twice.
        name = quote(unquote(key), safe="")
        # quote() leaves dots alone, so the directory aliases must be refused.
        if not name or name in (".", ".."):
            return None
        return name

    @staticmethod
    def _read_meta(meta_path):
        """Return the sidecar dict at *meta_path*, or {} if missing/corrupt."""
        try:
            with open(meta_path, encoding="utf-8") as fh:
                data = json.load(fh)
        except (OSError, ValueError):
            return {}
        return data if isinstance(data, dict) else {}

    @staticmethod
    def _payload_files():
        """Return ``[(path, stat_result)]`` for every payload file.

        Sidecars are skipped, and so are files that disappear between the
        directory listing and the stat call.
        """
        found = []
        for path in CACHE_DIR.iterdir():
            if path.name.endswith(".meta"):
                continue
            try:
                if path.is_file():
                    found.append((path, path.stat()))
            except OSError:
                continue
        return found

    def _start_cleaner(self):
        """Run ``cleanup()`` every 180 seconds on a daemon thread."""

        def loop():
            while True:
                time.sleep(180)
                try:
                    self.cleanup()
                except Exception as exc:
                    # One failed pass must not end periodic cleanup for the
                    # rest of the process lifetime.
                    print(f"Cache cleanup failed: {exc}")

        threading.Thread(target=loop, daemon=True).start()

    def get(self, key):
        """Return the path of the cached payload for *key*, or None.

        A hit increments the entry's ``hits`` counter.  An entry that is
        older than ``CACHE_TTL`` (or whose sidecar is unreadable) is deleted
        and reported as a miss.
        """
        name = self._safe_name(key)
        if name is None:
            return None
        fpath = CACHE_DIR / name
        meta = CACHE_DIR / f"{name}.meta"
        # Same lock as put()/cleanup(): the sidecar is rewritten below.
        with LOCK:
            try:
                if not (fpath.is_file() and meta.is_file()):
                    return None
                m = self._read_meta(meta)
                ts = m.get("ts", 0)
                fresh = (
                    isinstance(ts, (int, float))
                    and time.time() - ts < CACHE_TTL
                )
                if not fresh:
                    fpath.unlink(missing_ok=True)
                    meta.unlink(missing_ok=True)
                    return None
                hits = m.get("hits", 0)
                m["hits"] = (hits if isinstance(hits, int) else 0) + 1
                with open(meta, "w", encoding="utf-8") as fh:
                    json.dump(m, fh)
                return fpath
            except OSError:
                return None

    def put(self, key, src_path, content_type="video/mp4"):
        """Copy *src_path* into the cache under *key* and write its sidecar.

        Raises:
            ValueError: if *key* cannot be mapped to a safe file name.
            OSError: if the payload or the sidecar cannot be written.
        """
        name = self._safe_name(key)
        if name is None:
            raise ValueError(f"Invalid cache key: {key!r}")
        with LOCK:
            fpath = CACHE_DIR / name
            shutil.copy2(src_path, fpath)
            now = time.time()
            with open(CACHE_DIR / f"{name}.meta", "w", encoding="utf-8") as fh:
                json.dump({
                    "ts": now,
                    "expires": now + CACHE_TTL,
                    "hits": 1,
                    "size": os.path.getsize(fpath),
                    "content_type": content_type,
                }, fh)

    def stats(self):
        """Return ``(payload_file_count, total_payload_bytes)``."""
        with LOCK:
            sizes = [st.st_size for _, st in self._payload_files()]
        return len(sizes), sum(sizes)

    def cleanup(self):
        """Purge expired entries, then evict by hits/age if still too large.

        Files whose timestamp (sidecar ``ts``, falling back to the file
        mtime) is at least ``CACHE_TTL`` seconds old are removed; ``get()``
        would refuse to serve them anyway.  If the remaining payloads still
        total ``MAX_CACHE_BYTES`` or more, entries are removed in ascending
        ``(hits, ts)`` order until the total is at most half the limit.
        """
        now = time.time()
        with LOCK:
            survivors = []
            total = 0
            for path, st in self._payload_files():
                meta_path = CACHE_DIR / f"{path.name}.meta"
                m = self._read_meta(meta_path)
                hits = m.get("hits", 0)
                if not isinstance(hits, int):
                    hits = 0
                ts = m.get("ts", st.st_mtime)
                if not isinstance(ts, (int, float)):
                    ts = st.st_mtime
                if now - ts >= CACHE_TTL:
                    path.unlink(missing_ok=True)
                    meta_path.unlink(missing_ok=True)
                    continue
                survivors.append((hits, ts, st.st_size, path))
                total += st.st_size

            if total < MAX_CACHE_BYTES:
                return

            # Sort on (hits, ts) only; the remaining fields are payload data.
            survivors.sort(key=lambda entry: entry[:2])
            target = int(MAX_CACHE_BYTES * 0.5)
            for _hits, _ts, size, path in survivors:
                if total <= target:
                    break
                path.unlink(missing_ok=True)
                (CACHE_DIR / f"{path.name}.meta").unlink(missing_ok=True)
                total -= size
cache = SmartCache()
class Handler(http.server.BaseHTTPRequestHandler):
    """HTTP front end for the file cache.

    GET routes:
        /health        -- cache and disk statistics as JSON
        /stream/{key}  -- stream a cached file (single byte ranges supported)
        /list          -- list cached files with size, expiry and hit count

    POST routes (JSON body):
        /cache    {"key", "url"}            -- download *url* into the cache
        /preload  {"dataset", "file_name"}  -- download a Hugging Face
                                               dataset file into the cache
        /flush                              -- delete every cached file

    Cache keys are percent-encoded into a single path component before they
    reach the cache, so keys with slashes work and cannot point outside the
    cache directory.  The encoded key is what the JSON responses report;
    /stream accepts either form.
    """

    # Only these hosts (and their subdomains) ever receive HF_TOKEN.
    _HF_HOSTS = ("huggingface.co", "hf.co")

    # ------------------------------------------------------------------ GET

    def do_GET(self):
        """Dispatch GET requests to /health, /stream/{key} or /list."""
        route = urlparse(self.path).path
        if route == "/health":
            self._handle_health()
        elif route.startswith("/stream/"):
            self._handle_stream(route.split("/stream/", 1)[-1])
        elif route == "/list":
            self._handle_list()
        else:
            self._json(404, {"error": "Not found. Try /health, /stream/{key}, /list"})

    def _handle_health(self):
        """Report cache size, file count, free disk space and TTL."""
        files, size = cache.stats()
        free = shutil.disk_usage("/tmp").free
        self._json(200, {
            "status": "ok",
            "cached_files": files,
            "cache_size_bytes": size,
            "cache_size_gb": round(size / (1024**3), 2),
            "disk_free_bytes": free,
            "disk_free_gb": round(free / (1024**3), 1),
            "ttl_hours": CACHE_TTL / 3600,
        })

    def _handle_stream(self, raw_key):
        """Serve the cached file for *raw_key*, honouring a Range header."""
        if not raw_key:
            self._json(400, {"error": "Missing file key"})
            return
        key = self._cache_key(raw_key)
        if key is None:
            self._json(400, {"error": "Invalid file key"})
            return
        fpath = cache.get(key)
        if fpath:
            self._stream_file(fpath, key, self.headers.get("Range", ""))
        else:
            self._json(404, {"error": "Not in cache. Use /preload first."})

    def _handle_list(self):
        """List every payload file in the cache directory."""
        files = []
        for f in sorted(CACHE_DIR.iterdir()):
            if f.name.endswith(".meta"):
                continue
            try:
                if not f.is_file():
                    continue
                size = f.stat().st_size
            except OSError:
                # The file disappeared while the directory was being listed.
                continue
            meta = self._load_meta(CACHE_DIR / f"{f.name}.meta")
            files.append({
                "key": f.name,
                "size": size,
                "expires_at": self._format_expiry(meta.get("expires", 0)),
                "hits": meta.get("hits", 0),
            })
        self._json(200, {"files": files, "count": len(files)})

    # ----------------------------------------------------------------- POST

    def do_POST(self):
        """Dispatch POST requests to /cache, /preload or /flush."""
        try:
            body = self._read_json_body()
        except ValueError:
            self._json(400, {"error": "Invalid JSON body"})
            return
        route = urlparse(self.path).path
        if route == "/cache":
            self._handle_cache(body)
        elif route == "/preload":
            self._handle_preload(body)
        elif route == "/flush":
            self._handle_flush()
        else:
            self._json(404, {"error": "Not found"})

    def _read_json_body(self):
        """Return the request body as a dict ({} when there is no body).

        Raises:
            ValueError: on a malformed Content-Length, invalid JSON, or a
                JSON document that is not an object.
        """
        length = int(self.headers.get("Content-Length", 0))
        # A negative length would make rfile.read() block until EOF.
        if length <= 0:
            return {}
        body = json.loads(self.rfile.read(length))
        if not isinstance(body, dict):
            raise ValueError("JSON body must be an object")
        return body

    def _handle_cache(self, body):
        """Download ``body["url"]`` and store it under ``body["key"]``."""
        raw_key = body.get("key", "")
        url = body.get("url", "")
        if not raw_key or not url:
            self._json(400, {"error": "Missing key or url"})
            return
        key = self._cache_key(raw_key)
        if key is None:
            self._json(400, {"error": "Invalid key"})
            return
        # urllib would also open file:// and ftp:// URLs; the URL comes from
        # the client, so only plain web downloads are accepted.
        if not self._is_http_url(url):
            self._json(400, {"error": "Unsupported URL scheme"})
            return
        self._download_and_reply(key, url, "cached")

    def _handle_preload(self, body):
        """Download a Hugging Face dataset file into the cache."""
        dataset = body.get("dataset", "")
        file_name = body.get("file_name", "")
        if not dataset or not file_name:
            self._json(400, {"error": "Missing dataset or file_name"})
            return
        if not isinstance(dataset, str) or not isinstance(file_name, str):
            self._json(400, {"error": "Invalid dataset or file_name"})
            return
        key = self._cache_key(f"{dataset}/{file_name}")
        if key is None:
            self._json(400, {"error": "Invalid dataset or file_name"})
            return
        # Quote each part so spaces and other special characters form a
        # valid URL; "/" is kept because both parts may contain sub-paths.
        dl_url = (
            "https://huggingface.co/datasets/"
            f"{quote(dataset, safe='/')}/resolve/main/{quote(file_name, safe='/')}"
        )
        self._download_and_reply(key, dl_url, "preloaded")

    def _handle_flush(self):
        """Delete every regular file in the cache directory."""
        with LOCK:
            count = 0
            for f in CACHE_DIR.iterdir():
                if f.is_file():
                    f.unlink(missing_ok=True)
                    count += 1
        self._json(200, {"status": "flushed", "removed": count})

    def _download_and_reply(self, key, url, status):
        """Fetch *url* into the cache as *key* and send the JSON outcome."""
        try:
            self._fetch_into_cache(key, url)
        except Exception as e:
            self._json(500, {"error": str(e)[:200]})
            return
        self._json(200, {"status": status, "key": key})

    def _fetch_into_cache(self, key, url, content_type="video/mp4"):
        """Download *url* to a temp file and hand it to the cache."""
        temp = CACHE_DIR / f"dl_{int(time.time()*1000)}_{os.urandom(2).hex()}"
        headers = {}
        # Never send the token to third-party hosts, and never send an
        # empty "Bearer " header.
        if HF_TOKEN and self._is_hf_url(url):
            headers["Authorization"] = f"Bearer {HF_TOKEN}"
        req = urllib.request.Request(url, headers=headers)
        try:
            with urllib.request.urlopen(req, timeout=300) as src, \
                    open(temp, "wb") as dst:
                shutil.copyfileobj(src, dst)
            cache.put(key, temp, content_type)
        finally:
            # Remove the temp file on failure too, so aborted downloads do
            # not pile up in the cache directory.
            temp.unlink(missing_ok=True)

    # -------------------------------------------------------------- helpers

    @staticmethod
    def _cache_key(raw):
        """Return *raw* as one safe path component, or None if unusable.

        The result is idempotent: encoding an already-encoded key returns
        it unchanged.
        """
        if not isinstance(raw, str):
            return None
        name = quote(unquote(raw), safe="")
        # quote() leaves dots alone, so the directory aliases must be refused.
        if not name or name in (".", ".."):
            return None
        return name

    @staticmethod
    def _is_http_url(url):
        """Return True if *url* is a string with an http or https scheme."""
        if not isinstance(url, str):
            return False
        try:
            return urlparse(url).scheme in ("http", "https")
        except ValueError:
            return False

    @staticmethod
    def _is_hf_url(url):
        """Return True if *url* points at a Hugging Face host."""
        try:
            host = (urlparse(url).hostname or "").lower()
        except ValueError:
            return False
        return any(
            host == h or host.endswith("." + h) for h in Handler._HF_HOSTS
        )

    @staticmethod
    def _load_meta(meta_path):
        """Return the sidecar dict at *meta_path*, or {} if missing/corrupt."""
        try:
            with open(meta_path, encoding="utf-8") as fh:
                data = json.load(fh)
        except (OSError, ValueError):
            return {}
        return data if isinstance(data, dict) else {}

    @staticmethod
    def _format_expiry(expires):
        """Return *expires* (epoch seconds) as ISO text, or "N/A"."""
        if not expires or not isinstance(expires, (int, float)):
            return "N/A"
        try:
            return datetime.fromtimestamp(expires).isoformat()
        except (OverflowError, OSError, ValueError):
            return "N/A"

    @staticmethod
    def _parse_range(range_header, file_size):
        """Resolve a Range header against *file_size*.

        Returns ``(status, start, end)`` with inclusive byte offsets:
        ``200`` for the whole file (no header, or one that must be ignored),
        ``206`` for a satisfiable range with ``end`` clamped to the file,
        ``416`` when the requested range lies entirely outside the file.
        Only the first range of a multi-range header is considered.
        """
        whole = (200, 0, file_size - 1)
        match = re.match(r"bytes=(\d*)-(\d*)", range_header or "")
        if not match:
            return whole
        first, last = match.groups()
        if not first and not last:
            return whole
        if first:
            start = int(first)
            if last and int(last) < start:
                # An inverted range is syntactically invalid: ignore it.
                return whole
            end = min(int(last), file_size - 1) if last else file_size - 1
        else:
            # Suffix form "bytes=-N": the final N bytes of the file.
            suffix = int(last)
            if suffix == 0:
                return (416, 0, -1)
            start = max(file_size - suffix, 0)
            end = file_size - 1
        if start >= file_size:
            return (416, 0, -1)
        return (206, start, end)

    def _stream_file(self, fpath, filename, range_header):
        """Send *fpath* to the client, fully or as a single byte range."""
        try:
            # Open before sending headers: the handle keeps the data
            # readable even if the entry is evicted mid-transfer.
            src = open(fpath, "rb")
        except OSError:
            self._json(404, {"error": "Not in cache. Use /preload first."})
            return
        with src:
            file_size = os.fstat(src.fileno()).st_size
            status, start, end = self._parse_range(range_header, file_size)
            if status == 416:
                self.send_response(416)
                self.send_header("Content-Range", f"bytes */{file_size}")
                self.send_header("Content-Length", "0")
                self.send_header("Accept-Ranges", "bytes")
                self.end_headers()
                return

            length = end - start + 1
            self.send_response(status)
            if status == 206:
                self.send_header("Content-Range", f"bytes {start}-{end}/{file_size}")
            self.send_header("Content-Type", "video/mp4")
            self.send_header("Content-Length", str(length))
            self.send_header("Accept-Ranges", "bytes")
            self.send_header("Cache-Control", f"public, max-age={CACHE_TTL}")
            # Keep the header well-formed whatever the key contains.
            safe_name = re.sub(r'[^\x20-\x7e]|["\\]', "_", filename)
            self.send_header("Content-Disposition", f'inline; filename="{safe_name}"')
            self.end_headers()

            src.seek(start)
            remaining = length
            try:
                while remaining > 0:
                    chunk = src.read(min(65536, remaining))
                    if not chunk:
                        break
                    self.wfile.write(chunk)
                    remaining -= len(chunk)
            except (BrokenPipeError, ConnectionResetError):
                # The client closed the connection (e.g. a player seeking);
                # there is nobody left to report to.
                return

    def _json(self, code, data):
        """Send *data* as a JSON response with HTTP status *code*."""
        payload = json.dumps(data, ensure_ascii=False, default=str).encode()
        self.send_response(code)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(payload)))
        self.send_header("Access-Control-Allow-Origin", "*")
        self.end_headers()
        self.wfile.write(payload)

    def log_message(self, fmt, *args):
        """Silence the default per-request logging."""
        pass
if __name__ == "__main__":
    # Listening port comes from $PORT and falls back to 7860.
    port = int(os.getenv("PORT", 7860))
    max_gb = MAX_CACHE_BYTES / (1024**3)

    print(f"Proxy v2 on :{port}")
    print(f"Cache: {CACHE_DIR} | TTL: {CACHE_TTL}s | Max: {max_gb:.0f}GB")

    # Bind on all interfaces and serve until the process is stopped.
    server = http.server.HTTPServer(("0.0.0.0", port), Handler)
    server.serve_forever()