Fred808 commited on
Commit
cc80389
·
verified ·
1 Parent(s): 1eb637f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -12
app.py CHANGED
@@ -602,23 +602,54 @@ def check_existing_model(model_path: str) -> bool:
602
  return all(f in model_files for f in required_files) and has_weights
603
 
604
  async def download_model_files():
605
- """Downloads the model files using git clone from Hugging Face repository"""
606
  try:
607
  print(f"[INFO] Processing model from {Settings.MODEL_REPO}...")
608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
609
  # Create models directory
610
  models_dir = os.path.join(os.getcwd(), "models")
611
  os.makedirs(models_dir, exist_ok=True)
612
  print(f"[INFO] Models directory: {models_dir}")
613
 
614
  # Get the model name from the repository URL
615
- model_name = Settings.MODEL_REPO.split('/')[-1]
 
616
 
617
  # Create versioned model directory
618
  version = get_next_model_version(models_dir, model_name)
619
  model_base_dir = os.path.join(models_dir, model_name)
620
  model_version_dir = os.path.join(model_base_dir, f"v{version}")
621
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
622
  # Check if previous version exists and is valid
623
  if version > 1:
624
  prev_version_dir = os.path.join(model_base_dir, f"v{version-1}")
@@ -627,21 +658,53 @@ async def download_model_files():
627
  model_path = prev_version_dir
628
  state.is_model_loaded = True
629
  else:
630
- # Clone new version if previous is invalid or incomplete
631
  os.makedirs(model_version_dir, exist_ok=True)
632
- success = clone_repository(Settings.MODEL_REPO, model_version_dir)
633
- if not success:
634
- raise Exception("Failed to clone repository")
635
  model_path = model_version_dir
636
- print(f"[INFO] Successfully cloned model to {model_path}")
637
  else:
638
  # First time download
639
  os.makedirs(model_version_dir, exist_ok=True)
640
- success = clone_repository(Settings.MODEL_REPO, model_version_dir)
641
- if not success:
642
- raise Exception("Failed to clone repository")
643
  model_path = model_version_dir
644
- print(f"[INFO] Successfully cloned model to {model_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
645
 
646
  # Set model paths in state
647
  state.model_path = model_path
@@ -1091,7 +1154,7 @@ if __name__ == "__main__":
1091
  print(f"[INFO] API Documentation available at http://localhost:{port}/docs")
1092
 
1093
  uvicorn.run(
1094
- "app:app",
1095
  host="0.0.0.0",
1096
  port=port,
1097
  reload=False
 
602
  return all(f in model_files for f in required_files) and has_weights
603
 
604
  async def download_model_files():
605
+ """Downloads the model files using Hugging Face Hub API"""
606
  try:
607
  print(f"[INFO] Processing model from {Settings.MODEL_REPO}...")
608
 
609
+ # Install required packages if not present
610
+ required_packages = ["huggingface_hub", "requests", "tqdm"]
611
+ for package in required_packages:
612
+ try:
613
+ __import__(package)
614
+ except ImportError:
615
+ print(f"[INFO] Installing {package}...")
616
+ import subprocess
617
+ subprocess.check_call(["pip", "install", package])
618
+
619
+ from huggingface_hub import hf_hub_download, snapshot_download, HfFolder
620
+ import requests
621
+ from tqdm import tqdm
622
+
623
  # Create models directory
624
  models_dir = os.path.join(os.getcwd(), "models")
625
  os.makedirs(models_dir, exist_ok=True)
626
  print(f"[INFO] Models directory: {models_dir}")
627
 
628
  # Get the model name from the repository URL
629
+ repo_id = "/".join(Settings.MODEL_REPO.split('/')[-2:]) # e.g., "facebook/opt-125m"
630
+ model_name = repo_id.split('/')[-1]
631
 
632
  # Create versioned model directory
633
  version = get_next_model_version(models_dir, model_name)
634
  model_base_dir = os.path.join(models_dir, model_name)
635
  model_version_dir = os.path.join(model_base_dir, f"v{version}")
636
 
637
+ # Function to download file with progress bar
638
+ def download_file(url, filename):
639
+ response = requests.get(url, stream=True)
640
+ total_size = int(response.headers.get('content-length', 0))
641
+
642
+ with open(filename, 'wb') as f, tqdm(
643
+ desc=os.path.basename(filename),
644
+ total=total_size,
645
+ unit='iB',
646
+ unit_scale=True,
647
+ unit_divisor=1024,
648
+ ) as pbar:
649
+ for data in response.iter_content(chunk_size=1024):
650
+ size = f.write(data)
651
+ pbar.update(size)
652
+
653
  # Check if previous version exists and is valid
654
  if version > 1:
655
  prev_version_dir = os.path.join(model_base_dir, f"v{version-1}")
 
658
  model_path = prev_version_dir
659
  state.is_model_loaded = True
660
  else:
661
+ # Download new version
662
  os.makedirs(model_version_dir, exist_ok=True)
 
 
 
663
  model_path = model_version_dir
 
664
  else:
665
  # First time download
666
  os.makedirs(model_version_dir, exist_ok=True)
 
 
 
667
  model_path = model_version_dir
668
+
669
+ if not state.is_model_loaded:
670
+ try:
671
+ print(f"[INFO] Downloading model files from {repo_id}...")
672
+
673
+ # First download config and other small files
674
+ config_files = ["config.json", "tokenizer_config.json", "vocab.json", "generation_config.json"]
675
+ for filename in config_files:
676
+ try:
677
+ file_path = hf_hub_download(
678
+ repo_id=repo_id,
679
+ filename=filename,
680
+ local_dir=model_path,
681
+ force_download=True
682
+ )
683
+ print(f"[INFO] Downloaded {filename}")
684
+ except Exception as e:
685
+ print(f"[WARN] Could not download {filename}: {str(e)}")
686
+
687
+ # Then download the model weights
688
+ print("[INFO] Downloading model weights (this may take a while)...")
689
+ for weight_file in ["pytorch_model.bin", "model.safetensors"]:
690
+ try:
691
+ file_path = hf_hub_download(
692
+ repo_id=repo_id,
693
+ filename=weight_file,
694
+ local_dir=model_path,
695
+ force_download=True
696
+ )
697
+ print(f"[INFO] Successfully downloaded {weight_file}")
698
+ break # Stop after first successful weight file download
699
+ except Exception as e:
700
+ print(f"[WARN] Could not download {weight_file}: {str(e)}")
701
+ continue
702
+
703
+ print(f"[INFO] All files downloaded to {model_path}")
704
+ state.is_model_loaded = True
705
+
706
+ except Exception as e:
707
+ raise Exception(f"Failed to download model files: {str(e)}")
708
 
709
  # Set model paths in state
710
  state.model_path = model_path
 
1154
  print(f"[INFO] API Documentation available at http://localhost:{port}/docs")
1155
 
1156
  uvicorn.run(
1157
+ "controller_server_new:app",
1158
  host="0.0.0.0",
1159
  port=port,
1160
  reload=False