lyimo commited on
Commit
49328f3
·
verified ·
1 Parent(s): dbf80df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -25
app.py CHANGED
@@ -2,20 +2,19 @@ import os
2
  import requests
3
  from tqdm import tqdm
4
  import logging
 
5
 
6
- def download_sam_model(url="https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
7
- save_dir="."):
8
  """
9
- Download the SAM model weights to the root directory with progress tracking.
10
-
11
- Args:
12
- url (str): URL of the model weights
13
- save_dir (str): Directory to save the downloaded file (defaults to current directory)
14
  """
 
 
15
  try:
16
- # Extract filename from URL
 
17
  filename = url.split('/')[-1]
18
- save_path = os.path.join(save_dir, filename)
19
 
20
  # Setup logging
21
  logging.basicConfig(level=logging.INFO)
@@ -23,15 +22,14 @@ def download_sam_model(url="https://dl.fbaipublicfiles.com/segment_anything/sam_
23
 
24
  # Check if file already exists
25
  if os.path.exists(save_path):
26
- logger.info(f"File {filename} already exists in {save_dir}")
27
- return save_path
28
 
29
- # Send a HEAD request to get the file size
30
  response = requests.head(url)
31
  file_size = int(response.headers.get('content-length', 0))
32
 
33
  # Download the file with progress bar
34
- logger.info(f"Downloading {filename} to root directory")
35
  response = requests.get(url, stream=True)
36
  progress = tqdm(total=file_size, unit='iB', unit_scale=True)
37
 
@@ -43,21 +41,19 @@ def download_sam_model(url="https://dl.fbaipublicfiles.com/segment_anything/sam_
43
  progress.close()
44
 
45
  if file_size != 0 and progress.n != file_size:
46
- logger.error("Error during download - incomplete file")
47
  raise Exception("Downloaded file size does not match expected size")
48
 
49
- logger.info(f"Successfully downloaded {filename}")
50
- return save_path
51
 
52
  except Exception as e:
53
  logger.error(f"Error downloading file: {str(e)}")
54
- raise
55
-
 
 
 
 
 
 
56
  if __name__ == "__main__":
57
- MODEL_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
58
-
59
- try:
60
- downloaded_path = download_sam_model(MODEL_URL)
61
- print(f"Model downloaded successfully to: {downloaded_path}")
62
- except Exception as e:
63
- print(f"Failed to download model: {str(e)}")
 
2
  import requests
3
  from tqdm import tqdm
4
  import logging
5
+ import gradio as gr
6
 
7
+ def download_sam_model():
 
8
  """
9
+ Download the SAM model weights to the Hugging Face Space repository root
 
 
 
 
10
  """
11
+ url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
12
+
13
  try:
14
+ # Get the repository root directory (same level as app.py)
15
+ root_dir = os.path.dirname(os.path.abspath(__file__))
16
  filename = url.split('/')[-1]
17
+ save_path = os.path.join(root_dir, filename)
18
 
19
  # Setup logging
20
  logging.basicConfig(level=logging.INFO)
 
22
 
23
  # Check if file already exists
24
  if os.path.exists(save_path):
25
+ return gr.Info("Model file already exists in the repository root")
 
26
 
27
+ # Send a HEAD request to get file size
28
  response = requests.head(url)
29
  file_size = int(response.headers.get('content-length', 0))
30
 
31
  # Download the file with progress bar
32
+ logger.info(f"Downloading {filename} to Space repository root")
33
  response = requests.get(url, stream=True)
34
  progress = tqdm(total=file_size, unit='iB', unit_scale=True)
35
 
 
41
  progress.close()
42
 
43
  if file_size != 0 and progress.n != file_size:
 
44
  raise Exception("Downloaded file size does not match expected size")
45
 
46
+ return gr.Info("Model downloaded successfully!")
 
47
 
48
  except Exception as e:
49
  logger.error(f"Error downloading file: {str(e)}")
50
+ return gr.Error(f"Failed to download model: {str(e)}")
51
+
52
+ # Create a simple Gradio interface
53
+ with gr.Blocks() as demo:
54
+ gr.Markdown("# Download SAM Model")
55
+ download_button = gr.Button("Download Model")
56
+ download_button.click(fn=download_sam_model)
57
+
58
  if __name__ == "__main__":
59
+ demo.launch()