Update app.py
Browse files
app.py
CHANGED
|
@@ -29,6 +29,10 @@ HARDCODED_DATASET_REPO_ID = "Fred808/helium_memory" # Change this to your datas
|
|
| 29 |
# Hardcoded path in repository
|
| 30 |
HARDCODED_PATH_IN_REPO = "model_data/"
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
def download_model_files() -> str:
|
| 33 |
"""Download all files from a model repository to a temporary directory."""
|
| 34 |
print(f"Downloading all files from model {HARDCODED_MODEL_REPO_ID}...")
|
|
@@ -109,8 +113,9 @@ def cleanup_download(temp_dir: str):
|
|
| 109 |
except Exception as e:
|
| 110 |
print(f"Cleanup failed: {str(e)}")
|
| 111 |
|
| 112 |
-
async def
|
| 113 |
-
"""
|
|
|
|
| 114 |
temp_dir = None
|
| 115 |
try:
|
| 116 |
# Download the model files
|
|
@@ -134,9 +139,12 @@ async def transfer_model():
|
|
| 134 |
)
|
| 135 |
|
| 136 |
print(f"Model {HARDCODED_MODEL_REPO_ID} transferred successfully to {HARDCODED_DATASET_REPO_ID}")
|
|
|
|
| 137 |
|
| 138 |
except Exception as e:
|
| 139 |
-
|
|
|
|
|
|
|
| 140 |
raise
|
| 141 |
finally:
|
| 142 |
# Clean up downloaded files
|
|
@@ -146,31 +154,35 @@ async def transfer_model():
|
|
| 146 |
@app.get("/")
|
| 147 |
async def root():
|
| 148 |
"""Health check endpoint."""
|
|
|
|
| 149 |
return {
|
| 150 |
"message": "Hugging Face Model Transfer Service is running",
|
| 151 |
-
"
|
| 152 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
}
|
| 154 |
|
| 155 |
@app.on_event("startup")
|
| 156 |
async def startup_event():
|
| 157 |
"""Run the transfer process when the application starts."""
|
| 158 |
print("Starting model transfer process...")
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
@app.get("/status")
|
| 162 |
-
async def get_status():
|
| 163 |
-
"""Get server status."""
|
| 164 |
-
try:
|
| 165 |
-
return {
|
| 166 |
-
"status": "healthy",
|
| 167 |
-
"model": HARDCODED_MODEL_REPO_ID,
|
| 168 |
-
"dataset": HARDCODED_DATASET_REPO_ID,
|
| 169 |
-
"path_in_repo": HARDCODED_PATH_IN_REPO
|
| 170 |
-
}
|
| 171 |
-
except Exception as e:
|
| 172 |
-
raise HTTPException(status_code=500, detail=f"Status check failed: {str(e)}")
|
| 173 |
|
| 174 |
if __name__ == "__main__":
|
| 175 |
import uvicorn
|
|
|
|
| 176 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
|
| 29 |
# Hardcoded path in repository
|
| 30 |
HARDCODED_PATH_IN_REPO = "model_data/"
|
| 31 |
|
| 32 |
+
# Track transfer status
|
| 33 |
+
transfer_completed = False
|
| 34 |
+
transfer_error = None
|
| 35 |
+
|
| 36 |
def download_model_files() -> str:
|
| 37 |
"""Download all files from a model repository to a temporary directory."""
|
| 38 |
print(f"Downloading all files from model {HARDCODED_MODEL_REPO_ID}...")
|
|
|
|
| 113 |
except Exception as e:
|
| 114 |
print(f"Cleanup failed: {str(e)}")
|
| 115 |
|
| 116 |
+
async def run_transfer():
|
| 117 |
+
"""Run the transfer process and update status."""
|
| 118 |
+
global transfer_completed, transfer_error
|
| 119 |
temp_dir = None
|
| 120 |
try:
|
| 121 |
# Download the model files
|
|
|
|
| 139 |
)
|
| 140 |
|
| 141 |
print(f"Model {HARDCODED_MODEL_REPO_ID} transferred successfully to {HARDCODED_DATASET_REPO_ID}")
|
| 142 |
+
transfer_completed = True
|
| 143 |
|
| 144 |
except Exception as e:
|
| 145 |
+
error_msg = f"Transfer failed: {str(e)}"
|
| 146 |
+
print(error_msg)
|
| 147 |
+
transfer_error = error_msg
|
| 148 |
raise
|
| 149 |
finally:
|
| 150 |
# Clean up downloaded files
|
|
|
|
| 154 |
@app.get("/")
|
| 155 |
async def root():
|
| 156 |
"""Health check endpoint."""
|
| 157 |
+
status = "completed" if transfer_completed else "running" if transfer_error is None else "failed"
|
| 158 |
return {
|
| 159 |
"message": "Hugging Face Model Transfer Service is running",
|
| 160 |
+
"status": status,
|
| 161 |
+
"model": HARDCODED_MODEL_REPO_ID,
|
| 162 |
+
"dataset": HARDCODED_DATASET_REPO_ID,
|
| 163 |
+
"error": transfer_error
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
@app.get("/status")
|
| 167 |
+
async def get_status():
|
| 168 |
+
"""Get transfer status."""
|
| 169 |
+
status = "completed" if transfer_completed else "running" if transfer_error is None else "failed"
|
| 170 |
+
return {
|
| 171 |
+
"status": status,
|
| 172 |
+
"model": HARDCODED_MODEL_REPO_ID,
|
| 173 |
+
"dataset": HARDCODED_DATASET_REPO_ID,
|
| 174 |
+
"path_in_repo": HARDCODED_PATH_IN_REPO,
|
| 175 |
+
"error": transfer_error
|
| 176 |
}
|
| 177 |
|
| 178 |
@app.on_event("startup")
|
| 179 |
async def startup_event():
|
| 180 |
"""Run the transfer process when the application starts."""
|
| 181 |
print("Starting model transfer process...")
|
| 182 |
+
# Run transfer in background without waiting for completion
|
| 183 |
+
asyncio.create_task(run_transfer())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
|
| 185 |
if __name__ == "__main__":
|
| 186 |
import uvicorn
|
| 187 |
+
# Run the server indefinitely
|
| 188 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|