lyimo commited on
Commit
dbf80df
·
verified ·
1 Parent(s): 2e98dab

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ from tqdm import tqdm
4
+ import logging
5
+
6
+ def download_sam_model(url="https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
7
+ save_dir="."):
8
+ """
9
+ Download the SAM model weights to the root directory with progress tracking.
10
+
11
+ Args:
12
+ url (str): URL of the model weights
13
+ save_dir (str): Directory to save the downloaded file (defaults to current directory)
14
+ """
15
+ try:
16
+ # Extract filename from URL
17
+ filename = url.split('/')[-1]
18
+ save_path = os.path.join(save_dir, filename)
19
+
20
+ # Setup logging
21
+ logging.basicConfig(level=logging.INFO)
22
+ logger = logging.getLogger(__name__)
23
+
24
+ # Check if file already exists
25
+ if os.path.exists(save_path):
26
+ logger.info(f"File {filename} already exists in {save_dir}")
27
+ return save_path
28
+
29
+ # Send a HEAD request to get the file size
30
+ response = requests.head(url)
31
+ file_size = int(response.headers.get('content-length', 0))
32
+
33
+ # Download the file with progress bar
34
+ logger.info(f"Downloading {filename} to root directory")
35
+ response = requests.get(url, stream=True)
36
+ progress = tqdm(total=file_size, unit='iB', unit_scale=True)
37
+
38
+ with open(save_path, 'wb') as file:
39
+ for data in response.iter_content(chunk_size=1024):
40
+ progress.update(len(data))
41
+ file.write(data)
42
+
43
+ progress.close()
44
+
45
+ if file_size != 0 and progress.n != file_size:
46
+ logger.error("Error during download - incomplete file")
47
+ raise Exception("Downloaded file size does not match expected size")
48
+
49
+ logger.info(f"Successfully downloaded {filename}")
50
+ return save_path
51
+
52
+ except Exception as e:
53
+ logger.error(f"Error downloading file: {str(e)}")
54
+ raise
55
+
56
+ if __name__ == "__main__":
57
+ MODEL_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
58
+
59
+ try:
60
+ downloaded_path = download_sam_model(MODEL_URL)
61
+ print(f"Model downloaded successfully to: {downloaded_path}")
62
+ except Exception as e:
63
+ print(f"Failed to download model: {str(e)}")