# bigbossmonster's picture
# Rename app.py to main.py
# 4e1a716 verified
import os
import re
import shutil
import subprocess
import sys
from pathlib import Path
from typing import Optional, List
from urllib.parse import urlparse

import requests
from fastapi import FastAPI, HTTPException, BackgroundTasks
from huggingface_hub import HfApi
from pydantic import BaseModel
from pySmartDL import SmartDL
# --- Dependency check ---
# Upgrade gdown at import time. pip is invoked through the running interpreter
# (``sys.executable -m pip``) so the package lands in the environment this
# process actually uses, rather than whichever ``pip`` happens to be on PATH.
try:
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--upgrade", "gdown"],
        check=True,
        capture_output=True,
        text=True,
    )
except subprocess.CalledProcessError as e:
    print(f"Error occurred installing gdown: {e.stderr}")
except OSError as e:
    # The installer could not be launched at all; keep starting up, as the
    # original best-effort check does when pip itself fails.
    print(f"Error occurred installing gdown: {e}")

# --- Configuration ---
# Directory that receives every download and is later uploaded as a whole.
DIR = "download"
os.makedirs(DIR, exist_ok=True)
# Hugging Face token lookup; the value is read from the process environment.
def get_token():
    """Return the value of the ``HF_token`` environment variable, or ``None`` if unset."""
    return os.environ.get('HF_token')
def get_civit_key():
    """Return the value of the ``civitai_api`` environment variable, or ``None`` if unset."""
    return os.environ.get('civitai_api')
# Application object; the route decorators in this file register handlers on it,
# and it is what uvicorn serves when the module is run as a script.
app = FastAPI(title="File Management API")
# --- Pydantic Models for Request Validation ---
# Request body accepted by POST /process.
class ProcessRequest(BaseModel):
    # One URL per line; blank lines are skipped by the endpoint.
    urls: str  # Accepts newline separated string to match original logic, or you could change to List[str]
    # Target Hugging Face repo id; an empty value falls back to the same default at upload time.
    hf_path: Optional[str] = "bigbossmonster/output"
    # Download backend. Any value other than default/wget/aria2/curl is handled by SmartDL.
    selection: str = "default"  # options: default, wget, SmartDl, aria2, curl
    hf_check: bool = False  # If True, skip upload
    # Passed to the Hugging Face upload as the repository type.
    upload_type: str = "model"  # options: model, dataset
    # When provided, used instead of the HF_token environment variable.
    hf_token: Optional[str] = None  # Optional override for token
# --- Core Logic Functions (Refactored) ---
def wipe_folder_logic():
    """Empty the download directory by deleting it and creating it again.

    Returns:
        A human-readable status string: success, "does not exist", or the
        text of whatever exception interrupted the wipe. Nothing is raised.
    """
    try:
        # Guard clause: nothing to remove when the folder is absent.
        if not os.path.exists(DIR):
            return f"The folder '{DIR}' does not exist."
        shutil.rmtree(DIR)
        os.makedirs(DIR, exist_ok=True)
    except Exception as e:
        return f"Error wiping folder: {str(e)}"
    return f"All files in '{DIR}' have been wiped successfully."
def extract_file_name(url):
    """Return the final ``/``-separated segment of the URL's path.

    Query string and fragment are not part of the path, so they are excluded.
    A path that is empty or ends in ``/`` yields an empty string.
    """
    _, _, last_segment = urlparse(url).path.rpartition('/')
    return last_segment
def download_file(url, save_path, selection):
    """Download ``url`` to ``save_path`` using the backend named by ``selection``.

    Args:
        url: Address of the file to fetch.
        save_path: Destination file path. Its parent directory is created
            when missing. Every backend writes to this path; only when the
            path has no file-name component do wget/aria2 fall back to
            choosing their own name inside the target directory.
        selection: ``"default"`` streams the body with ``requests``;
            ``"wget"``, ``"aria2"`` and ``"curl"`` shell out to those tools;
            any other value uses ``SmartDL``.

    Returns:
        ``True`` when the backend completed without reporting an error,
        ``False`` when a handled download error was printed. Existing callers
        that ignore the return value are unaffected.

    Raises:
        OSError: Filesystem problems and a missing command-line tool are not
            handled here and propagate to the caller, as before.
    """
    target_dir = os.path.dirname(save_path) or "."
    file_name = os.path.basename(save_path)
    os.makedirs(target_dir, exist_ok=True)

    if selection == "default":
        try:
            # (connect, read) timeouts: the read timeout applies between
            # chunks, not to the whole transfer, so large files still work
            # while a stalled connection no longer blocks forever.
            with requests.get(url, stream=True, timeout=(15, 120)) as response:
                response.raise_for_status()
                with open(save_path, 'wb') as file:
                    for chunk in response.iter_content(chunk_size=8192):
                        file.write(chunk)
            print(f"File downloaded to {save_path}")
            return True
        except requests.RequestException as e:
            print(f"Failed to download file: {e}")
            return False

    if selection in ["wget", "aria2", "curl"]:
        if selection == "wget":
            if file_name:
                command = ["wget", "-O", save_path, url]
            else:
                command = ["wget", url, "-P", target_dir]
        elif selection == "aria2":
            command = ["aria2c", "-x16", "-s16", "-d", target_dir]
            if file_name:
                # aria2 interprets -o relative to the -d directory.
                command += ["-o", file_name]
            command.append(url)
        else:
            # -L follows redirects; --fail turns HTTP errors into a non-zero
            # exit instead of saving the error page as the file.
            command = ["curl", "-L", "--fail", "-o", save_path, url]
        try:
            subprocess.run(command, check=True)
            print(f"File downloaded to {target_dir} using {selection}")
            return True
        except subprocess.CalledProcessError as e:
            print(f"Failed to download file using {selection}: {e}")
            return False

    # Any other selection value: SmartDL
    try:
        obj = SmartDL(url, save_path)
        obj.start()
        print(f"File downloaded successfully to: {obj.get_dest()} using SmartDL")
        return True
    except Exception as e:
        print(f"Failed to download file using SmartDL: {e}")
        return False
# Compiled once at import. Path-based forms (/file/d/<id>, /drive/folders/<id>)
# are tried before the query form so that a trailing parameter whose name merely
# ends in "id" (e.g. "ouid=") cannot be mistaken for the file id. The query form
# requires "id" to be a real parameter name (preceded by "?" or "&").
_DRIVE_ID_PATTERN = re.compile(
    r'drive\.google\.com/'
    r'(?:file/d/|drive/folders/|[^\s#]*?[?&]id=)'
    r'([a-zA-Z0-9_-]+)'
)


def extract_google_drive_id(url):
    """Extract the file or folder id from a Google Drive URL.

    Recognised shapes are ``/file/d/<id>``, ``/drive/folders/<id>`` and any
    ``drive.google.com`` URL carrying an ``id=<id>`` query parameter (which
    covers the ``open?id=`` and ``uc?id=`` forms).

    Args:
        url: The URL text to search.

    Returns:
        The id string, or ``None`` when the URL matches none of the shapes.
    """
    match = _DRIVE_ID_PATTERN.search(url)
    return match.group(1) if match else None
def download_google_drive(url):
    """Download a Google Drive file or folder into the download directory with gdown.

    Args:
        url: A Google Drive link. Links containing ``"folders"`` are passed
            to gdown unchanged in folder mode; other links are rewritten to
            the ``uc?id=<id>`` form first.

    Returns:
        A status string: ``"Downloaded: ..."``, ``"Failed to download: ..."``
        when gdown exits non-zero, ``"Error: ..."`` for any other failure
        (for example gdown not being installed), or
        ``"Invalid Google Drive URL: ..."`` when no id can be extracted.
    """
    file_id = extract_google_drive_id(url)
    if not file_id:
        return f"Invalid Google Drive URL: {url}"

    download_url = f"https://drive.google.com/uc?id={file_id}"
    print(f"Downloading from: {download_url}")

    if "folders" in url:
        command = ["gdown", "--remaining-ok", "--folder", url]
    else:
        command = ["gdown", download_url]

    try:
        os.makedirs(DIR, exist_ok=True)
        # gdown writes into its working directory. Setting cwd only for the
        # child process avoids os.chdir(), which changes the directory for the
        # whole server and was left un-restored when an unexpected error hit.
        subprocess.run(command, check=True, cwd=DIR)
    except subprocess.CalledProcessError as e:
        return f"Failed to download: {e}"
    except Exception as e:
        return f"Error: {e}"
    return f"Downloaded: {download_url}"
def download_hugging_face(url, selection):
    """Download a Hugging Face URL into the download directory.

    The file is stored under the last path segment of ``url`` and fetched with
    the backend named by ``selection``. Returns ``"Downloaded: <url>"``.
    """
    destination = os.path.join(DIR, extract_file_name(url))
    download_file(url, destination, selection)
    return f"Downloaded: {url}"
def download_katfile(url, selection):
    """Resolve a Katfile link to a direct URL through the Katfile API and download it.

    The file is first cloned via ``/api/file/clone``, then a direct link for
    the clone is requested via ``/api/file/direct_link`` and handed to
    ``download_file``.

    Args:
        url: Katfile link whose first path segment is the file code.
        selection: Download backend name forwarded to ``download_file``.

    Returns:
        Always a status string (never ``None``): ``"Downloaded: ..."`` on
        success, or a message describing which step failed.
    """
    # NOTE(review): this credential is committed in source. Prefer supplying it
    # through the ``katfile_api`` environment variable and rotating this value;
    # the literal is kept only as a backward-compatible fallback.
    api_key = os.getenv("katfile_api", "699996yph6h88a7rc6c1g8")
    try:
        parsed = urlparse(url)
        host = (parsed.hostname or "").lower()
        file_code = parsed.path.lstrip('/').split('/')[0]
        # The API key is sent to this host, so accept only genuine Katfile
        # hosts; a substring match elsewhere in the URL is not enough.
        is_katfile_host = host == "katfile.com" or host.endswith(".katfile.com")
        if not is_katfile_host or not file_code:
            return f"Invalid Katfile URL: {url}"
        domain = parsed.netloc

        clone_url = f"https://{domain}/api/file/clone?key={api_key}&file_code={file_code}"
        clone_response = requests.get(clone_url, timeout=30)
        if clone_response.status_code != 200:
            return f"Katfile clone request failed with HTTP {clone_response.status_code}"
        cloned_link = clone_response.json().get('result', {}).get('url')
        if not cloned_link:
            return "Failed to get Katfile URL"

        # The cloned link has the same shape as the input: scheme://host/<code>/...
        cloned_code = cloned_link.split('/')[3]
        direct_url = f"https://{domain}/api/file/direct_link?key={api_key}&file_code={cloned_code}"
        direct_response = requests.get(direct_url, timeout=30)
        if direct_response.status_code != 200:
            return f"Katfile direct-link request failed with HTTP {direct_response.status_code}"
        real_dl_url = direct_response.json().get('result', {}).get('url')
        if not real_dl_url:
            return "Failed to get Katfile URL"

        save_path = os.path.join(DIR, extract_file_name(real_dl_url))
        download_file(real_dl_url, save_path, selection)
        return f"Downloaded: {real_dl_url}"
    except Exception as e:
        print(f"Error downloading from Katfile: {str(e)}")
        return "Error downloading Katfile"
def download_civitai(url, selection):
    """Download a CivitAI model file identified by the last path segment of ``url``.

    Args:
        url: CivitAI link; its final path segment is used as the model
            version id.
        selection: ``"curl"`` fetches ``url`` directly with a bearer-token
            header; any other value downloads from the CivitAI download API
            through ``download_file``.

    Returns:
        ``"Downloaded: <url>"``, ``"Failed to download: <url>"`` when curl
        could not be run or exited non-zero, or ``"Invalid CivitAI URL: <url>"``
        when the URL has no trailing id segment.
    """
    model_version_id = urlparse(url).path.split("/")[-1]
    if not model_version_id:
        return f"Invalid CivitAI URL: {url}"

    civit_key = get_civit_key()

    if selection == "curl":
        # Temp name if unknown, usually CivitAI needs content-disposition
        save_path = os.path.join(DIR, f"{model_version_id}.safecheck")
        command = ["curl", "-L", "--fail"]
        if civit_key:
            # Only authenticate when a key is configured; otherwise the
            # literal text "None" would be sent as the token.
            command += ["-H", f"Authorization: Bearer {civit_key}"]
        command += ["-o", save_path, url]
        try:
            os.makedirs(DIR, exist_ok=True)
            subprocess.run(command, check=True)
        except (subprocess.CalledProcessError, OSError) as e:
            print(f"Failed to download file using curl: {e}")
            return f"Failed to download: {url}"
    else:
        api_url = f"https://civitai.com/api/download/models/{model_version_id}"
        if civit_key:
            api_url += f"?token={civit_key}"
        # Include the id in the file name so several CivitAI links in one
        # batch no longer overwrite each other.
        save_path = os.path.join(DIR, f"civitai_model_{model_version_id}")
        download_file(api_url, save_path, selection)
    return f"Downloaded: {url}"
def handle_download(url, selection):
    """Route ``url`` to the matching site-specific downloader.

    The site is chosen by substring match on the URL, checked in the order
    Google Drive, Hugging Face, CivitAI, Katfile; anything else is downloaded
    directly under its last path segment. Any exception is turned into an
    ``"Error occurred: ..."`` string instead of being raised.
    """
    # Marker text paired with the call to make when the marker is in the URL.
    site_handlers = (
        ("drive.google.com", lambda: download_google_drive(url)),
        ("huggingface.co", lambda: download_hugging_face(url, selection)),
        ("civitai.com", lambda: download_civitai(url, selection)),
        ("katfile.com", lambda: download_katfile(url, selection)),
    )
    try:
        for marker, run_handler in site_handlers:
            if marker in url:
                return run_handler()
        # No known site matched: plain download.
        target = os.path.join(DIR, extract_file_name(url))
        download_file(url, target, selection)
        return f"Downloaded: {url}"
    except Exception as e:
        return f"Error occurred: {str(e)}"
def upload_to_hf(hf_path, upload_type, token_val):
    """Upload the whole download directory to a Hugging Face repository.

    Args:
        hf_path: Target repo id; a falsy value selects "bigbossmonster/output".
        upload_type: Passed through as the repository type.
        token_val: Token passed to the upload call.

    Returns:
        A status string describing success, the upload error, or that the
        download directory is missing.
    """
    if not os.path.exists(DIR):
        return f"The folder '{DIR}' does not exist."

    repo_id = hf_path or "bigbossmonster/output"
    api = HfApi()
    upload_kwargs = {
        "folder_path": DIR,
        "repo_id": repo_id,
        "commit_message": "Upload folder with Python script",
        "token": token_val,
        "repo_type": upload_type,
    }
    try:
        api.upload_folder(**upload_kwargs)
    except Exception as e:
        return f"Failed to upload folder. Error: {e}"
    return f"Upload success! Folder is available at https://huggingface.co/{repo_id}"
# --- API Endpoints ---
@app.get("/")
def home():
    # Root route: a fixed status message pointing callers at the Swagger UI.
    payload = {"message": "File Management API is running. Go to /docs for Swagger UI."}
    return payload
@app.post("/process")
def process_batch(request: ProcessRequest):
    """
    Main endpoint to download files and optionally upload to HF.

    Each non-blank line of ``request.urls`` is downloaded in order with the
    backend named by ``request.selection``. Unless ``request.hf_check`` is
    true, the download directory is then uploaded to ``request.hf_path``.

    Returns a dict with ``download_status`` (one line per URL) and
    ``upload_status``.
    """
    # 1. Process Downloads
    download_results = []
    for line in request.urls.splitlines():
        url = line.strip()
        if not url:
            continue
        result = handle_download(url, request.selection)
        if result is None:
            # A handler that returns nothing would make the join below raise
            # TypeError and fail the whole batch; report it per URL instead.
            result = f"No status reported for: {url}"
        download_results.append(str(result))
    download_msg = "\n".join(download_results)

    # 2. Process Upload (skipped when hf_check is set)
    upload_msg = "Do Not upload to HF"
    if not request.hf_check:
        # Determine token: Request > Env Var
        active_token = request.hf_token if request.hf_token else get_token()
        upload_msg = upload_to_hf(request.hf_path, request.upload_type, active_token)

    return {
        "download_status": download_msg,
        "upload_status": upload_msg
    }
@app.post("/wipe")
def wipe_folder():
    """Wipes the download directory."""
    # The helper reports its outcome as text; wrap it for the JSON response.
    outcome = wipe_folder_logic()
    return {"status": outcome}
@app.get("/files")
def list_files_endpoint():
    """Lists files in the download directory."""
    try:
        if os.path.exists(DIR):
            entries = os.listdir(DIR)
            return {"files": entries, "count": len(entries)}
        # Missing directory is reported as an empty listing, not an error.
        return {"files": [], "message": "Directory does not exist."}
    except Exception as e:
        # Anything unexpected becomes an HTTP 500 carrying the error text.
        raise HTTPException(status_code=500, detail=str(e))
# Script entry point: serve the app with uvicorn when this file is run directly.
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when running as a script.
    import uvicorn
    # Listen on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)