mustafa2ak commited on
Commit
0e8476b
·
verified ·
1 Parent(s): e7ab977

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +28 -41
utils.py CHANGED
@@ -1,41 +1,28 @@
1
- import os, zipfile
2
- import numpy as np
3
- from PIL import Image
4
- import random
5
-
6
- # Use local path instead of /content
7
- EXTRACTION_PATH = "./dataset"
8
-
9
- def extract_dataset(zip_path="test.zip"):
10
- """Extracts test.zip into ./dataset/test"""
11
- if not os.path.exists(EXTRACTION_PATH):
12
- os.makedirs(EXTRACTION_PATH)
13
-
14
- with zipfile.ZipFile(zip_path, "r") as zip_ref:
15
- zip_ref.extractall(EXTRACTION_PATH)
16
-
17
- return os.path.join(EXTRACTION_PATH, "test")
18
-
19
- def filter_images_with_buildings(base_path):
20
- test_a = os.path.join(base_path, "A")
21
- test_b = os.path.join(base_path, "B")
22
- label = os.path.join(base_path, "label")
23
-
24
- image_files = sorted(os.listdir(test_a))
25
- valid_images = []
26
-
27
- for img_name in image_files:
28
- mask_path = os.path.join(label, img_name)
29
- mask = Image.open(mask_path).convert("L")
30
- mask_np = np.array(mask)
31
-
32
- if np.any(mask_np > 0): # keep only if mask has buildings
33
- valid_images.append({
34
- "image_a": os.path.join(test_a, img_name),
35
- "image_b": os.path.join(test_b, img_name),
36
- "mask": mask_path
37
- })
38
- return valid_images
39
-
40
- def get_random_valid_image(valid_images):
41
- return random.choice(valid_images)
 
1
+ import os
2
+ import zipfile
3
+ import shutil
4
+
5
+ def setup_dataset(zip_path="test.zip", extract_path="./dataset"):
6
+ """Extract dataset if not already extracted"""
7
+ dataset_path = os.path.join(extract_path, "test")
8
+
9
+ if not os.path.exists(dataset_path):
10
+ print("Extracting dataset...")
11
+ os.makedirs(extract_path, exist_ok=True)
12
+
13
+ try:
14
+ with zipfile.ZipFile(zip_path, 'r') as zip_ref:
15
+ zip_ref.extractall(extract_path)
16
+ print("Dataset extracted successfully!")
17
+ except Exception as e:
18
+ print(f"Error extracting dataset: {e}")
19
+ return None
20
+
21
+ return dataset_path
22
+
23
+ def cleanup_cache():
24
+ """Clean up any temporary files"""
25
+ cache_dirs = ['__pycache__', '.ipynb_checkpoints']
26
+ for cache_dir in cache_dirs:
27
+ if os.path.exists(cache_dir):
28
+ shutil.rmtree(cache_dir)