Update app.py
Browse files
app.py
CHANGED
|
@@ -602,23 +602,54 @@ def check_existing_model(model_path: str) -> bool:
|
|
| 602 |
return all(f in model_files for f in required_files) and has_weights
|
| 603 |
|
| 604 |
async def download_model_files():
|
| 605 |
-
"""Downloads the model files using
|
| 606 |
try:
|
| 607 |
print(f"[INFO] Processing model from {Settings.MODEL_REPO}...")
|
| 608 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 609 |
# Create models directory
|
| 610 |
models_dir = os.path.join(os.getcwd(), "models")
|
| 611 |
os.makedirs(models_dir, exist_ok=True)
|
| 612 |
print(f"[INFO] Models directory: {models_dir}")
|
| 613 |
|
| 614 |
# Get the model name from the repository URL
|
| 615 |
-
|
|
|
|
| 616 |
|
| 617 |
# Create versioned model directory
|
| 618 |
version = get_next_model_version(models_dir, model_name)
|
| 619 |
model_base_dir = os.path.join(models_dir, model_name)
|
| 620 |
model_version_dir = os.path.join(model_base_dir, f"v{version}")
|
| 621 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 622 |
# Check if previous version exists and is valid
|
| 623 |
if version > 1:
|
| 624 |
prev_version_dir = os.path.join(model_base_dir, f"v{version-1}")
|
|
@@ -627,21 +658,53 @@ async def download_model_files():
|
|
| 627 |
model_path = prev_version_dir
|
| 628 |
state.is_model_loaded = True
|
| 629 |
else:
|
| 630 |
-
#
|
| 631 |
os.makedirs(model_version_dir, exist_ok=True)
|
| 632 |
-
success = clone_repository(Settings.MODEL_REPO, model_version_dir)
|
| 633 |
-
if not success:
|
| 634 |
-
raise Exception("Failed to clone repository")
|
| 635 |
model_path = model_version_dir
|
| 636 |
-
print(f"[INFO] Successfully cloned model to {model_path}")
|
| 637 |
else:
|
| 638 |
# First time download
|
| 639 |
os.makedirs(model_version_dir, exist_ok=True)
|
| 640 |
-
success = clone_repository(Settings.MODEL_REPO, model_version_dir)
|
| 641 |
-
if not success:
|
| 642 |
-
raise Exception("Failed to clone repository")
|
| 643 |
model_path = model_version_dir
|
| 644 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 645 |
|
| 646 |
# Set model paths in state
|
| 647 |
state.model_path = model_path
|
|
@@ -1091,7 +1154,7 @@ if __name__ == "__main__":
|
|
| 1091 |
print(f"[INFO] API Documentation available at http://localhost:{port}/docs")
|
| 1092 |
|
| 1093 |
uvicorn.run(
|
| 1094 |
-
"
|
| 1095 |
host="0.0.0.0",
|
| 1096 |
port=port,
|
| 1097 |
reload=False
|
|
|
|
| 602 |
return all(f in model_files for f in required_files) and has_weights
|
| 603 |
|
| 604 |
async def download_model_files():
|
| 605 |
+
"""Downloads the model files using Hugging Face Hub API"""
|
| 606 |
try:
|
| 607 |
print(f"[INFO] Processing model from {Settings.MODEL_REPO}...")
|
| 608 |
|
| 609 |
+
# Install required packages if not present
|
| 610 |
+
required_packages = ["huggingface_hub", "requests", "tqdm"]
|
| 611 |
+
for package in required_packages:
|
| 612 |
+
try:
|
| 613 |
+
__import__(package)
|
| 614 |
+
except ImportError:
|
| 615 |
+
print(f"[INFO] Installing {package}...")
|
| 616 |
+
import subprocess
|
| 617 |
+
subprocess.check_call(["pip", "install", package])
|
| 618 |
+
|
| 619 |
+
from huggingface_hub import hf_hub_download, snapshot_download, HfFolder
|
| 620 |
+
import requests
|
| 621 |
+
from tqdm import tqdm
|
| 622 |
+
|
| 623 |
# Create models directory
|
| 624 |
models_dir = os.path.join(os.getcwd(), "models")
|
| 625 |
os.makedirs(models_dir, exist_ok=True)
|
| 626 |
print(f"[INFO] Models directory: {models_dir}")
|
| 627 |
|
| 628 |
# Get the model name from the repository URL
|
| 629 |
+
repo_id = "/".join(Settings.MODEL_REPO.split('/')[-2:]) # e.g., "facebook/opt-125m"
|
| 630 |
+
model_name = repo_id.split('/')[-1]
|
| 631 |
|
| 632 |
# Create versioned model directory
|
| 633 |
version = get_next_model_version(models_dir, model_name)
|
| 634 |
model_base_dir = os.path.join(models_dir, model_name)
|
| 635 |
model_version_dir = os.path.join(model_base_dir, f"v{version}")
|
| 636 |
|
| 637 |
+
# Function to download file with progress bar
|
| 638 |
+
def download_file(url, filename):
|
| 639 |
+
response = requests.get(url, stream=True)
|
| 640 |
+
total_size = int(response.headers.get('content-length', 0))
|
| 641 |
+
|
| 642 |
+
with open(filename, 'wb') as f, tqdm(
|
| 643 |
+
desc=os.path.basename(filename),
|
| 644 |
+
total=total_size,
|
| 645 |
+
unit='iB',
|
| 646 |
+
unit_scale=True,
|
| 647 |
+
unit_divisor=1024,
|
| 648 |
+
) as pbar:
|
| 649 |
+
for data in response.iter_content(chunk_size=1024):
|
| 650 |
+
size = f.write(data)
|
| 651 |
+
pbar.update(size)
|
| 652 |
+
|
| 653 |
# Check if previous version exists and is valid
|
| 654 |
if version > 1:
|
| 655 |
prev_version_dir = os.path.join(model_base_dir, f"v{version-1}")
|
|
|
|
| 658 |
model_path = prev_version_dir
|
| 659 |
state.is_model_loaded = True
|
| 660 |
else:
|
| 661 |
+
# Download new version
|
| 662 |
os.makedirs(model_version_dir, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
| 663 |
model_path = model_version_dir
|
|
|
|
| 664 |
else:
|
| 665 |
# First time download
|
| 666 |
os.makedirs(model_version_dir, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
| 667 |
model_path = model_version_dir
|
| 668 |
+
|
| 669 |
+
if not state.is_model_loaded:
|
| 670 |
+
try:
|
| 671 |
+
print(f"[INFO] Downloading model files from {repo_id}...")
|
| 672 |
+
|
| 673 |
+
# First download config and other small files
|
| 674 |
+
config_files = ["config.json", "tokenizer_config.json", "vocab.json", "generation_config.json"]
|
| 675 |
+
for filename in config_files:
|
| 676 |
+
try:
|
| 677 |
+
file_path = hf_hub_download(
|
| 678 |
+
repo_id=repo_id,
|
| 679 |
+
filename=filename,
|
| 680 |
+
local_dir=model_path,
|
| 681 |
+
force_download=True
|
| 682 |
+
)
|
| 683 |
+
print(f"[INFO] Downloaded {filename}")
|
| 684 |
+
except Exception as e:
|
| 685 |
+
print(f"[WARN] Could not download {filename}: {str(e)}")
|
| 686 |
+
|
| 687 |
+
# Then download the model weights
|
| 688 |
+
print("[INFO] Downloading model weights (this may take a while)...")
|
| 689 |
+
for weight_file in ["pytorch_model.bin", "model.safetensors"]:
|
| 690 |
+
try:
|
| 691 |
+
file_path = hf_hub_download(
|
| 692 |
+
repo_id=repo_id,
|
| 693 |
+
filename=weight_file,
|
| 694 |
+
local_dir=model_path,
|
| 695 |
+
force_download=True
|
| 696 |
+
)
|
| 697 |
+
print(f"[INFO] Successfully downloaded {weight_file}")
|
| 698 |
+
break # Stop after first successful weight file download
|
| 699 |
+
except Exception as e:
|
| 700 |
+
print(f"[WARN] Could not download {weight_file}: {str(e)}")
|
| 701 |
+
continue
|
| 702 |
+
|
| 703 |
+
print(f"[INFO] All files downloaded to {model_path}")
|
| 704 |
+
state.is_model_loaded = True
|
| 705 |
+
|
| 706 |
+
except Exception as e:
|
| 707 |
+
raise Exception(f"Failed to download model files: {str(e)}")
|
| 708 |
|
| 709 |
# Set model paths in state
|
| 710 |
state.model_path = model_path
|
|
|
|
| 1154 |
print(f"[INFO] API Documentation available at http://localhost:{port}/docs")
|
| 1155 |
|
| 1156 |
uvicorn.run(
|
| 1157 |
+
"controller_server_new:app",
|
| 1158 |
host="0.0.0.0",
|
| 1159 |
port=port,
|
| 1160 |
reload=False
|