Fred808 commited on
Commit
970ed33
·
verified ·
1 Parent(s): d86a9d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -19
app.py CHANGED
@@ -1,29 +1,51 @@
 
 
 
1
  from huggingface_hub import snapshot_download, HfApi
2
  import os
 
 
 
 
3
 
4
- # User's dataset repository ID
5
- DATASET_REPO_ID = "Fred808/helium_memory"
6
 
7
- # Model to download
8
- MODEL_REPO_ID = "openai/gpt-oss-120b"
9
 
10
  # Local directory to save downloaded model data temporarily
11
  DOWNLOAD_DIR = "./downloaded_model_data"
12
 
13
- # Hugging Face API Token (provided by user)
14
- HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
15
-
16
  # Ensure the download directory exists
17
  os.makedirs(DOWNLOAD_DIR, exist_ok=True)
18
 
19
- def download_full_model(repo_id, download_dir):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  """Downloads an entire model from a Hugging Face model repository."""
21
  print(f"Downloading full model from {repo_id}...")
22
  local_dir = snapshot_download(repo_id=repo_id, cache_dir=download_dir)
23
  print(f"Downloaded to: {local_dir}")
24
  return local_dir
25
 
26
- def upload_folder_to_dataset(dataset_repo_id, folder_path, path_in_repo, token):
27
  """Uploads a folder to a Hugging Face dataset repository."""
28
  api = HfApi(token=token)
29
  print(f"Uploading {folder_path} to {dataset_repo_id} at {path_in_repo}...")
@@ -35,19 +57,127 @@ def upload_folder_to_dataset(dataset_repo_id, folder_path, path_in_repo, token):
35
  )
36
  print("Upload complete!")
37
 
38
- if __name__ == "__main__":
39
- # 1. Download full model data
40
- downloaded_folder = download_full_model(MODEL_REPO_ID, DOWNLOAD_DIR)
 
41
 
42
- # 2. Upload the downloaded data to the user's dataset
43
- # The path in the dataset will be 'model_data/bert-base-uncased/'
44
- path_in_dataset = f"model_data/{MODEL_REPO_ID}/"
45
- upload_folder_to_dataset(DATASET_REPO_ID, downloaded_folder, path_in_dataset, HUGGINGFACE_TOKEN)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- print("Script finished successfully.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
 
50
 
 
51
 
52
- if not HUGGINGFACE_TOKEN:
53
- raise ValueError("HUGGINGFACE_TOKEN environment variable not set.")
 
1
+ from fastapi import FastAPI, HTTPException, Depends, BackgroundTasks
2
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
3
+ from pydantic import BaseModel
4
  from huggingface_hub import snapshot_download, HfApi
5
  import os
6
+ import shutil
7
+ from typing import Optional
8
+ import asyncio
9
+ from concurrent.futures import ThreadPoolExecutor
10
 
11
+ app = FastAPI(title="Hugging Face Model Transfer API", version="1.0.0")
12
+ security = HTTPBearer()
13
 
14
+ # Thread pool for running blocking operations
15
+ executor = ThreadPoolExecutor(max_workers=2)
16
 
17
  # Local directory to save downloaded model data temporarily
18
  DOWNLOAD_DIR = "./downloaded_model_data"
19
 
 
 
 
20
  # Ensure the download directory exists
21
  os.makedirs(DOWNLOAD_DIR, exist_ok=True)
22
 
23
+ class DownloadRequest(BaseModel):
24
+ model_repo_id: str
25
+ download_dir: Optional[str] = DOWNLOAD_DIR
26
+
27
+ class UploadRequest(BaseModel):
28
+ dataset_repo_id: str
29
+ folder_path: str
30
+ path_in_repo: str
31
+
32
+ class TransferRequest(BaseModel):
33
+ model_repo_id: str
34
+ dataset_repo_id: str
35
+ path_in_repo: Optional[str] = None
36
+
37
+ def get_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
38
+ """Extract and validate Hugging Face token from Authorization header."""
39
+ return credentials.credentials
40
+
41
+ def download_full_model(repo_id: str, download_dir: str) -> str:
42
  """Downloads an entire model from a Hugging Face model repository."""
43
  print(f"Downloading full model from {repo_id}...")
44
  local_dir = snapshot_download(repo_id=repo_id, cache_dir=download_dir)
45
  print(f"Downloaded to: {local_dir}")
46
  return local_dir
47
 
48
+ def upload_folder_to_dataset(dataset_repo_id: str, folder_path: str, path_in_repo: str, token: str):
49
  """Uploads a folder to a Hugging Face dataset repository."""
50
  api = HfApi(token=token)
51
  print(f"Uploading {folder_path} to {dataset_repo_id} at {path_in_repo}...")
 
57
  )
58
  print("Upload complete!")
59
 
60
+ @app.get("/")
61
+ async def root():
62
+ """Health check endpoint."""
63
+ return {"message": "Hugging Face Model Transfer API is running"}
64
 
65
+ @app.post("/download")
66
+ async def download_model(request: DownloadRequest, token: str = Depends(get_token)):
67
+ """Download a model from Hugging Face model repository."""
68
+ try:
69
+ # Run the blocking download operation in a thread pool
70
+ loop = asyncio.get_event_loop()
71
+ local_dir = await loop.run_in_executor(
72
+ executor,
73
+ download_full_model,
74
+ request.model_repo_id,
75
+ request.download_dir
76
+ )
77
+
78
+ return {
79
+ "message": f"Model {request.model_repo_id} downloaded successfully",
80
+ "local_path": local_dir
81
+ }
82
+ except Exception as e:
83
+ raise HTTPException(status_code=500, detail=f"Download failed: {str(e)}")
84
 
85
+ @app.post("/upload")
86
+ async def upload_folder(request: UploadRequest, token: str = Depends(get_token)):
87
+ """Upload a folder to a Hugging Face dataset repository."""
88
+ try:
89
+ # Check if folder exists
90
+ if not os.path.exists(request.folder_path):
91
+ raise HTTPException(status_code=404, detail=f"Folder not found: {request.folder_path}")
92
+
93
+ # Run the blocking upload operation in a thread pool
94
+ loop = asyncio.get_event_loop()
95
+ await loop.run_in_executor(
96
+ executor,
97
+ upload_folder_to_dataset,
98
+ request.dataset_repo_id,
99
+ request.folder_path,
100
+ request.path_in_repo,
101
+ token
102
+ )
103
+
104
+ return {
105
+ "message": f"Folder uploaded successfully to {request.dataset_repo_id}",
106
+ "path_in_repo": request.path_in_repo
107
+ }
108
+ except HTTPException:
109
+ raise
110
+ except Exception as e:
111
+ raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")
112
+
113
+ @app.post("/transfer")
114
+ async def transfer_model(request: TransferRequest, background_tasks: BackgroundTasks, token: str = Depends(get_token)):
115
+ """Download a model and upload it to a dataset repository (combined operation)."""
116
+ try:
117
+ # Set default path in repo if not provided
118
+ path_in_repo = request.path_in_repo or f"model_data/{request.model_repo_id}/"
119
+
120
+ # Download the model
121
+ loop = asyncio.get_event_loop()
122
+ local_dir = await loop.run_in_executor(
123
+ executor,
124
+ download_full_model,
125
+ request.model_repo_id,
126
+ DOWNLOAD_DIR
127
+ )
128
+
129
+ # Upload to dataset
130
+ await loop.run_in_executor(
131
+ executor,
132
+ upload_folder_to_dataset,
133
+ request.dataset_repo_id,
134
+ local_dir,
135
+ path_in_repo,
136
+ token
137
+ )
138
+
139
+ # Clean up downloaded files in background
140
+ background_tasks.add_task(cleanup_download, local_dir)
141
+
142
+ return {
143
+ "message": f"Model {request.model_repo_id} transferred successfully to {request.dataset_repo_id}",
144
+ "path_in_repo": path_in_repo
145
+ }
146
+ except Exception as e:
147
+ raise HTTPException(status_code=500, detail=f"Transfer failed: {str(e)}")
148
+
149
+ def cleanup_download(local_dir: str):
150
+ """Clean up downloaded files."""
151
+ try:
152
+ if os.path.exists(local_dir):
153
+ shutil.rmtree(local_dir)
154
+ print(f"Cleaned up: {local_dir}")
155
+ except Exception as e:
156
+ print(f"Cleanup failed: {str(e)}")
157
+
158
+ @app.get("/status")
159
+ async def get_status():
160
+ """Get server status and available disk space."""
161
+ try:
162
+ disk_usage = shutil.disk_usage(DOWNLOAD_DIR)
163
+ return {
164
+ "status": "healthy",
165
+ "download_dir": DOWNLOAD_DIR,
166
+ "disk_space": {
167
+ "total": disk_usage.total,
168
+ "used": disk_usage.used,
169
+ "free": disk_usage.free
170
+ }
171
+ }
172
+ except Exception as e:
173
+ raise HTTPException(status_code=500, detail=f"Status check failed: {str(e)}")
174
+
175
+ if __name__ == "__main__":
176
+ import uvicorn
177
+ uvicorn.run(app, host="0.0.0.0", port=7860)
178
 
179
 
180
 
181
+ live
182
 
183
+ Jump to live