File size: 3,736 Bytes
7a826c8
a8f0097
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a826c8
 
 
a8f0097
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a826c8
a8f0097
 
 
7a826c8
 
 
 
a8f0097
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import secrets
import shutil
import uuid
from pathlib import Path
from typing import List

from fastapi import FastAPI, File, UploadFile, Form, BackgroundTasks, HTTPException, Depends
from fastapi.responses import HTMLResponse, JSONResponse

from uploader import HuggingFaceUploader

app = FastAPI()

# Hugging Face credentials and target repository are read from the
# environment once, at import time, and handed to the project's uploader.
hf_token = os.environ.get("HF_TOKEN")
repo_id = os.environ.get("REPO_ID")
uploader = HuggingFaceUploader(hf_token, repo_id)

# Working directories, relative to the process's current working directory:
#   UPLOAD_FOLDER - staging area for files received by /upload
#   STATUS_FOLDER - per-job progress logs read by /progress
UPLOAD_FOLDER = Path("temp_uploads")
STATUS_FOLDER = Path("status")
for _workdir in (UPLOAD_FOLDER, STATUS_FOLDER):
    _workdir.mkdir(exist_ok=True)
del _workdir

# Shared secret that upload requests must present; None when PASSWORD is unset.
ACCESS_KEY = os.environ.get("PASSWORD")

def save_progress(status_file, filename, status):
    """Append a single ``"<filename>: <status>"`` line to *status_file*.

    The file is opened in append mode (so it is created if missing) and written
    as UTF-8; repeated calls therefore build up a line-per-event progress log.
    """
    entry = "{}: {}\n".format(filename, status)
    with status_file.open(mode="a", encoding="utf-8") as log:
        log.write(entry)

def upload_file_task(temp_folder_path: Path, destination_folder: str, status_file: Path):
    """Upload every file staged under *temp_folder_path* and log progress.

    Each regular file is pushed with ``uploader.api.upload_file`` to
    ``<destination_folder>/<path relative to temp_folder_path>`` in
    ``uploader.repo_id`` (repo_type ``"model"``). After each file, two lines are
    appended to *status_file*: ``"<name>: Uploaded"`` and a running
    ``"Progress: i/N files uploaded."`` counter. The first exception stops the
    loop and is recorded as an ``"error: <message>"`` line instead of being
    raised.

    The staging folder, including any nested subdirectories, is always removed
    afterwards.
    """
    try:
        # Snapshot the tree once and keep only regular files: rglob also yields
        # directories, which must not inflate the "i/N" denominator.
        staged_files = sorted(p for p in temp_folder_path.rglob("*") if p.is_file())
        total_files = len(staged_files)

        for uploaded_files, local_path in enumerate(staged_files, start=1):
            relative_path = local_path.relative_to(temp_folder_path)
            # Repo paths are "/"-separated even when this runs on Windows.
            path_in_repo = str(Path(destination_folder) / relative_path).replace("\\", "/")

            uploader.api.upload_file(
                path_or_fileobj=str(local_path),
                path_in_repo=path_in_repo,
                repo_id=uploader.repo_id,
                repo_type="model"
            )

            save_progress(status_file, str(local_path.name), "Uploaded")
            # Log current progress
            save_progress(status_file, "Progress", f"{uploaded_files}/{total_files} files uploaded.")

    except Exception as e:
        # The HTTP response has already been sent, so the status file is the
        # only place a client can learn about the failure.
        save_progress(status_file, "error", str(e))
    finally:
        # rmtree also removes nested subdirectories (unlink() on a directory
        # would raise), and ignore_errors keeps cleanup from raising here.
        shutil.rmtree(temp_folder_path, ignore_errors=True)

@app.get("/", response_class=HTMLResponse)
async def main():
    """Serve ``index.html`` from the process's working directory as the root page."""
    # Decode explicitly instead of relying on the platform default encoding
    # (e.g. cp1252 on Windows). Assumes index.html is saved as UTF-8.
    with open("index.html", encoding="utf-8") as f:
        return HTMLResponse(content=f.read())

def _staged_path(root: Path, filename) -> Path:
    """Map a client-supplied *filename* to a write location inside resolved *root*.

    Raises HTTPException(400) when the name is empty/missing or would resolve
    outside *root* (e.g. ``"../x"`` or an absolute path), because the name comes
    straight from the client.
    """
    if not filename:
        raise HTTPException(status_code=400, detail="Uploaded file is missing a filename")
    target = (root / filename).resolve()
    # root is never among its own parents, so "." and "" are rejected too.
    if root not in target.parents:
        raise HTTPException(status_code=400, detail=f"Invalid filename: {filename!r}")
    return target


@app.post("/upload")
async def upload_files(
    background_tasks: BackgroundTasks,
    access_key: str = Form(...),
    destination_folder: str = Form(...),
    files: List[UploadFile] = File(...)
):
    """Stage uploaded files locally and schedule their background upload.

    Returns ``{"status_id": <id>}``; pass the id to ``/progress/{status_id}``.
    Raises HTTPException(403) for a wrong access key (or when no key is
    configured) and HTTPException(400) for a missing or unsafe filename.
    """
    # Constant-time comparison so response timing does not reveal how much of
    # the key matched. An unset PASSWORD (ACCESS_KEY is None) rejects everything.
    if ACCESS_KEY is None or not secrets.compare_digest(
        access_key.encode("utf-8"), ACCESS_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Invalid access key")

    temp_folder_path = UPLOAD_FOLDER / uuid.uuid4().hex
    temp_folder_path.mkdir(parents=True, exist_ok=True)
    temp_root = temp_folder_path.resolve()

    # Save uploaded files temporarily
    try:
        for upload in files:
            file_path = _staged_path(temp_root, upload.filename)
            # Names like "dir/file" are allowed; create their parent folders.
            file_path.parent.mkdir(parents=True, exist_ok=True)
            with file_path.open("wb") as buffer:
                shutil.copyfileobj(upload.file, buffer)
    except Exception:
        # Do not leave a partially written staging folder behind.
        shutil.rmtree(temp_folder_path, ignore_errors=True)
        raise

    # Create the status file up front so the progress endpoint finds it while
    # the job is still waiting for its first file to finish.
    status_file = STATUS_FOLDER / f"{temp_folder_path.name}.txt"
    status_file.touch()
    background_tasks.add_task(upload_file_task, temp_folder_path, destination_folder, status_file)
    return JSONResponse(content={"status_id": temp_folder_path.name})

@app.get("/progress/{status_id}")
async def get_progress(status_id: str):
    """Get the progress of the upload from the status file.

    Responds with ``{"progress": [...]}``: one entry per line of the job's status
    file, trailing newlines included. Raises HTTPException(404) when *status_id*
    is malformed or no status file exists for it.
    """
    # Upload ids are always uuid4().hex (32 lowercase hex characters). Reject
    # anything else so client input cannot steer the file path.
    if len(status_id) != 32 or not all(ch in "0123456789abcdef" for ch in status_id):
        raise HTTPException(status_code=404, detail="Status not found")

    status_file = STATUS_FOLDER / f"{status_id}.txt"
    # Open directly (EAFP) rather than exists()-then-open, which can race.
    try:
        with status_file.open("r", encoding="utf-8") as f:
            progress = f.readlines()
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="Status not found") from None
    return JSONResponse(content={"progress": progress})