# tryit / main.py
# Author: triflix
# Commit: "Update main.py" (c70a192, verified)
import os
import shutil
import time
import glob
import asyncio
import mimetypes
import pathlib
import re
from datetime import datetime, timedelta
from fastapi import FastAPI, UploadFile, File, Form, Request, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse, FileResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
API_VERSION = "1.0"

# Directory where uploads are stored; overridable through the TMP_DIR
# environment variable, otherwise the container's /tmp.
TMP_DIR = os.environ.get("TMP_DIR", "/tmp")

MAX_BYTES = 2 * 1024 ** 3             # largest accepted upload: 2 GiB
CLEANUP_INTERVAL_SECONDS = 10 * 60    # pause between cleanup passes: 10 minutes
EXPIRE_SECONDS = 3 * 60 * 60          # age after which a file is removed: 3 hours
CHUNK_SIZE = 1024 ** 2                # uploads are copied in 1 MiB pieces

# Extensions that are refused at upload time (lowercase, no leading dot).
DISALLOWED_EXT = {
    "bat",
    "exe",
    "cmd",
    "sh",
    "msi",
    "ps1",
    "com",
    "scr",
}

# Make sure the storage directory is present before any request arrives.
os.makedirs(TMP_DIR, exist_ok=True)

# The interactive docs, ReDoc and the OpenAPI schema routes are all disabled.
app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")
def sanitize_slug(s: str) -> str:
    """Return *s* reduced to characters that are safe to use in a slug.

    Every character that is not a word character (``\\w``), a hyphen or a
    dot is dropped, and the result is cut to at most 128 characters.
    """
    disallowed = re.compile(r"[^\w\-\.]")
    cleaned = disallowed.sub("", s)
    max_length = 128
    return cleaned[:max_length]
def file_exists_for_slug(slug: str):
    """Return the path of the stored file for *slug*, or ``None``.

    Uploads live directly inside ``TMP_DIR`` under the name
    ``<slug><ext>``, where ``<ext>`` is the (possibly empty) extension of
    the uploaded file name.

    The slug may come straight from a request URL, so it is treated as
    untrusted:

    * glob metacharacters are escaped, so a slug such as ``*`` cannot be
      used to match somebody else's file;
    * anything containing a path separator is rejected;
    * only regular files are returned, never directories.

    Args:
        slug: The identifier chosen or generated at upload time.

    Returns:
        The full path of the matching file, or ``None`` when no file is
        stored for this slug.
    """
    if not slug or os.path.basename(slug) != slug:
        return None

    # An upload whose name had no extension is stored under the bare slug,
    # which a "<slug>.*" pattern alone would never find.
    exact = os.path.join(TMP_DIR, slug)
    if os.path.isfile(exact):
        return exact

    pattern = os.path.join(glob.escape(TMP_DIR), glob.escape(slug) + ".*")
    # Sorted so the result is deterministic when several files match.
    for candidate in sorted(glob.glob(pattern)):
        stem, _ = os.path.splitext(os.path.basename(candidate))
        # "a.*" also matches "a.b.txt", which belongs to slug "a.b".
        if stem == slug and os.path.isfile(candidate):
            return candidate
    return None
def make_file_path(slug: str, filename: str):
    """Build the storage path for an upload.

    The stored name is *slug* followed by the lower-cased extension of
    *filename* (nothing when *filename* has no extension), placed inside
    ``TMP_DIR``.
    """
    suffix = os.path.splitext(filename)[1].lower()
    stored_name = slug + suffix
    return os.path.join(TMP_DIR, stored_name)
def gen_slug(length=8):
    """Return a random slug of *length* lowercase ASCII letters and digits.

    Characters are drawn one at a time with ``secrets.choice``.
    """
    import secrets
    import string

    pool = string.ascii_lowercase + string.digits
    picked = [secrets.choice(pool) for _ in range(length)]
    return "".join(picked)
async def save_upload_to_tmp(upload_file: UploadFile, dest_path: str):
    """Stream *upload_file* to *dest_path*, enforcing the ``MAX_BYTES`` limit.

    The upload is read ``CHUNK_SIZE`` bytes at a time. The running total is
    checked before each chunk is written, so no more than ``MAX_BYTES``
    bytes ever reach the disk.

    Args:
        upload_file: The incoming upload; only its async ``read`` is used.
        dest_path: Where the data is written (created or truncated).

    Returns:
        The number of bytes written.

    Raises:
        HTTPException: 413 when the upload is larger than ``MAX_BYTES``.
        OSError: When the destination cannot be opened or written.

    Whenever copying stops early (size limit, I/O error, or the request
    being cancelled) the partially written file is removed before the
    exception propagates.
    """
    total = 0
    # Opened outside the guarded region: if the file cannot be created
    # there is nothing of ours to clean up.
    f = open(dest_path, "wb")
    try:
        with f:
            while True:
                chunk = await upload_file.read(CHUNK_SIZE)
                if not chunk:
                    break
                total += len(chunk)
                if total > MAX_BYTES:
                    raise HTTPException(status_code=413, detail="File exceeds max size (2GB).")
                # NOTE(review): this is a blocking write inside a coroutine.
                f.write(chunk)
    except BaseException:
        # BaseException so that asyncio.CancelledError (client disconnect)
        # is covered too. The `with` above has already closed the file.
        try:
            os.remove(dest_path)
        except OSError:
            pass
        raise
    return total
@app.on_event("startup")
async def startup_event():
    """Start the periodic cleanup coroutine when the application starts."""
    # The event loop only keeps weak references to tasks, so a task whose
    # handle is dropped may be garbage-collected mid-run. Keep a strong
    # reference on the application state for the life of the process.
    app.state.cleaner_task = asyncio.create_task(cleaner_task())
async def cleaner_task():
    """Run forever, deleting expired uploads from ``TMP_DIR``.

    Each pass looks at every entry directly inside ``TMP_DIR`` and removes
    the regular files whose modification time is more than
    ``EXPIRE_SECONDS`` in the past. Directories are left alone. Errors are
    ignored so that one bad entry (or one bad pass) never stops the loop.
    The task then sleeps ``CLEANUP_INTERVAL_SECONDS`` before the next pass.
    """
    while True:
        try:
            started = time.time()
            entries = glob.glob(os.path.join(TMP_DIR, "*"))
            for entry in entries:
                try:
                    # Skip anything that is not a regular file.
                    if not os.path.isfile(entry):
                        continue
                    age = started - os.path.getmtime(entry)
                    if age > EXPIRE_SECONDS:
                        os.remove(entry)
                except Exception:
                    # Best effort: move on to the next entry.
                    continue
        except Exception:
            # Best effort: try again on the next pass.
            pass
        await asyncio.sleep(CLEANUP_INTERVAL_SECONDS)
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render ``index.html`` with the API version and the upload limits."""
    context = dict(
        request=request,
        api_version=API_VERSION,
        max_bytes=MAX_BYTES,
        expire_seconds=EXPIRE_SECONDS,
    )
    return templates.TemplateResponse("index.html", context)
@app.post("/api/upload")
async def api_upload(file: UploadFile = File(...), custom_slug: str = Form(None)):
    """Store an uploaded file and return its slug, URL, size and expiry.

    Args:
        file: The uploaded file (multipart field ``file``).
        custom_slug: Optional slug requested by the client (form field). It
            is passed through ``sanitize_slug``; when omitted a random slug
            is generated.

    Returns:
        A JSON object with ``slug``, ``url``, ``filename``, ``size`` (bytes
        written) and ``expires_at`` (Unix timestamp, seconds).

    Raises:
        HTTPException: 400 for a disallowed file type or an unusable custom
            slug, 409 when the custom slug is taken, 413 when the file is
            too large (raised by ``save_upload_to_tmp``), 500 when the file
            could not be saved.
    """
    filename = file.filename or "upload"
    _, ext = os.path.splitext(filename)
    if ext.lower().lstrip(".") in DISALLOWED_EXT:
        raise HTTPException(status_code=400, detail=f"Disallowed file type: {ext}")

    # Choose the slug.
    if custom_slug:
        slug = sanitize_slug(custom_slug)
        # Empty, or made only of dots ("." / ".."), which would name a
        # directory entry rather than a file.
        if not slug.strip("."):
            raise HTTPException(status_code=400, detail="Invalid custom slug.")
        if file_exists_for_slug(slug):
            raise HTTPException(status_code=409, detail="Slug already exists.")
    else:
        # Try a handful of short slugs; fall back to a longer one if every
        # attempt collides with an existing file.
        for _ in range(8):
            slug = gen_slug(8)
            if not file_exists_for_slug(slug):
                break
        else:
            slug = gen_slug(16)

    dest = make_file_path(slug, filename)

    # Slugs may contain dots, so a slug like "x.exe" combined with an
    # extension-less upload would otherwise slip past the blacklist.
    stored_ext = os.path.splitext(dest)[1].lower().lstrip(".")
    if stored_ext in DISALLOWED_EXT:
        raise HTTPException(status_code=400, detail=f"Disallowed file type: .{stored_ext}")

    # Save the file (the size limit is enforced while streaming).
    try:
        bytes_written = await save_upload_to_tmp(file, dest)
    except HTTPException:
        raise
    except Exception as exc:
        # Best-effort removal of whatever was written before the failure.
        try:
            os.remove(dest)
        except OSError:
            pass
        raise HTTPException(status_code=500, detail="Failed to save file.") from exc

    # The cleaner expires files by mtime, which was just set by the write.
    # time.time() is already a UTC-based epoch value; converting a naive
    # utcnow() with .timestamp() would be off by the host's UTC offset.
    expires_at = int(time.time()) + EXPIRE_SECONDS
    return JSONResponse({
        "slug": slug,
        "url": f"/f/{slug}",
        "filename": filename,
        "size": bytes_written,
        "expires_at": expires_at,
    })
@app.get("/f/{slug}")
async def serve_file(slug: str, dl: int = 0):
    """Serve the stored file for *slug*.

    Args:
        slug: Identifier returned by the upload endpoint.
        dl: Any non-zero value forces a download (``attachment``).

    Returns:
        A ``FileResponse`` for the stored file. Images, video, audio, text
        and PDF are sent with an ``inline`` disposition so browsers display
        them; everything else, and every request with ``dl`` set, is sent
        as an ``attachment``.

    Raises:
        HTTPException: 404 when no file is stored for the slug.
    """
    not_found = HTTPException(status_code=404, detail="File not found or expired.")

    # Stored slugs are always produced by sanitize_slug() or gen_slug(), so
    # a value that sanitising would change (glob metacharacters, over-long
    # input, ...) cannot name a stored file. Reject it before it is used
    # in a filesystem lookup.
    if not slug or sanitize_slug(slug) != slug:
        raise not_found

    path = file_exists_for_slug(slug)
    if not path or not os.path.isfile(path):
        raise not_found

    guessed, _ = mimetypes.guess_type(path)
    media_type = guessed or "application/octet-stream"

    # Uploads are untrusted: types a browser would execute in this origin
    # are never shown inline, and content sniffing is switched off.
    scriptable = {"text/html", "text/xml", "image/svg+xml"}
    inline_prefixes = ("image/", "video/", "audio/", "text/", "application/pdf")
    show_inline = (
        not dl
        and media_type.startswith(inline_prefixes)
        and media_type not in scriptable
    )

    return FileResponse(
        path,
        media_type=media_type,
        filename=os.path.basename(path),
        headers={"X-Content-Type-Options": "nosniff"},
        # Passing `filename` alone always yields "attachment"; the type
        # must be set explicitly for inline display.
        content_disposition_type="inline" if show_inline else "attachment",
    )
@app.get("/api/info")
async def api_info():
    """Describe the API: version, endpoints, limits and a curl example."""
    upload_endpoint = "/api/upload"
    # The example must POST to the upload endpoint; "/f" has no POST route.
    curl_example = (
        "curl -X POST -H \"Accept: application/json\" "
        "-F \"file=@/path/to/file\" "
        f"https://triflix-tryit.hf.space{upload_endpoint} --output -"
    )
    return {
        "version": API_VERSION,
        "upload_endpoint": upload_endpoint,
        "file_endpoint_example": "/f/<slug>",
        "max_size_bytes": MAX_BYTES,
        "expiry_seconds": EXPIRE_SECONDS,
        "curl_example": curl_example,
    }