forouzanfallah commited on
Commit
36d5aef
·
verified ·
1 Parent(s): bfa3575

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -25
app.py CHANGED
@@ -25,7 +25,6 @@ TARGET_PER_PERSON = 20
25
  CONTACT_EMAIL = "ffallah@asu.edu"
26
 
27
  # --- Paths ---
28
- CAPTIONS_JSON_PATH = os.environ.get("CAPTIONS_JSON_PATH", "data/captions.json")
29
 
30
  GT_MASKED_DIR = "data/gt_b" # Image 1
31
  GT_UNMASKED_DIR = "data/adc_b" # Image 2
@@ -147,42 +146,70 @@ def load_image(path: str) -> Image.Image:
147
  except Exception:
148
  return Image.new("RGB", (256, 256), color="gray")
149
 
150
- def load_dataset(captions_path: str, gt_masked_dir: str, gt_unmasked_dir: str, sr_dir: str, original_dir: str, image_5_dir: str) -> List[Sample]:
151
- if not os.path.exists(captions_path):
152
- print(f"Captions file not found at {captions_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  return []
154
 
155
- with open(captions_path, "r", encoding="utf-8") as f:
156
- try:
157
- captions_data = json.load(f)
158
- except Exception:
159
- print("Failed to parse captions JSON.")
160
- return []
161
 
162
  samples: List[Sample] = []
163
- for item in captions_data:
164
- base_filename = item.get("image")
165
- if not base_filename:
166
- continue
167
-
168
  sample_id = os.path.splitext(base_filename)[0]
 
169
  paths = {
170
- "masked": os.path.join(gt_masked_dir, base_filename),
171
  "unmasked": os.path.join(gt_unmasked_dir, base_filename),
172
- "sr": os.path.join(sr_dir, base_filename),
173
  "original": os.path.join(original_dir, base_filename),
174
- "img5": os.path.join(image_5_dir, base_filename)
175
  }
176
 
177
- # If strict enforcement required, require all five files to exist.
178
  if STRICT_ENFORCEMENT:
179
  if not all(os.path.exists(p) for p in paths.values()):
180
  missing = [k for k, v in paths.items() if not os.path.exists(v)]
181
- print(f"Skipping {base_filename}: Missing in folders {missing}")
182
  continue
183
 
184
- # In non-strict mode, it's okay to include samples even if some files missing;
185
- # we will supply placeholders at load time.
186
  samples.append(
187
  Sample(
188
  sample_id=sample_id,
@@ -190,12 +217,13 @@ def load_dataset(captions_path: str, gt_masked_dir: str, gt_unmasked_dir: str, s
190
  unmasked_gt_path=paths["unmasked"],
191
  sr_path=paths["sr"],
192
  original_path=paths["original"],
193
- image_5_path=paths["img5"]
194
  )
195
  )
196
 
197
  return samples
198
 
 
199
  # ----------------------
200
  # Progress & results I/O
201
  # ----------------------
@@ -258,7 +286,7 @@ def start_or_resume(name: str, email: str):
258
  raise gr.Error("Please enter your name and email to begin.")
259
 
260
  ensure_paths()
261
- samples = load_dataset(CAPTIONS_JSON_PATH, GT_MASKED_DIR, GT_UNMASKED_DIR, SR_DIR, ORIGINAL_DIR, IMAGE_5_DIR)
262
 
263
  if not samples:
264
  raise gr.Error("No images found. Please check dataset configuration.")
@@ -716,7 +744,7 @@ if __name__ == "__main__":
716
  print(f"Error reading from HF: {e}")
717
 
718
  ensure_paths()
719
- _ = load_dataset(CAPTIONS_JSON_PATH, GT_MASKED_DIR, GT_UNMASKED_DIR, SR_DIR, ORIGINAL_DIR, IMAGE_5_DIR)
720
 
721
  print("✅ Launching app.")
722
  demo.queue()
 
25
  CONTACT_EMAIL = "ffallah@asu.edu"
26
 
27
  # --- Paths ---
 
28
 
29
  GT_MASKED_DIR = "data/gt_b" # Image 1
30
  GT_UNMASKED_DIR = "data/adc_b" # Image 2
 
146
  except Exception:
147
  return Image.new("RGB", (256, 256), color="gray")
148
 
149
+ def load_dataset(
150
+ gt_masked_dir: str,
151
+ gt_unmasked_dir: str,
152
+ sr_dir: str,
153
+ original_dir: str,
154
+ image_5_dir: str,
155
+ ) -> List[Sample]:
156
+ """
157
+ Build samples only from the 5 folders.
158
+ Each folder should have the same filenames.
159
+ Example layout:
160
+ data/gt_b/xxx.png
161
+ data/adc_b/xxx.png
162
+ data/sr_b/xxx.png
163
+ data/lr_b/xxx.png
164
+ data/see_b/xxx.png
165
+ """
166
+
167
+ def list_images(dir_path: str) -> set:
168
+ if not os.path.isdir(dir_path):
169
+ print(f"Warning: directory not found: {dir_path}")
170
+ return set()
171
+ files = []
172
+ for f in os.listdir(dir_path):
173
+ f_lower = f.lower()
174
+ if f_lower.endswith((".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp")):
175
+ files.append(f)
176
+ return set(files)
177
+
178
+ masked_files = list_images(gt_masked_dir)
179
+ unmasked_files = list_images(gt_unmasked_dir)
180
+ sr_files = list_images(sr_dir)
181
+ orig_files = list_images(original_dir)
182
+ img5_files = list_images(image_5_dir)
183
+
184
+ # Common filenames present in ALL 5 folders
185
+ common_files = masked_files & unmasked_files & sr_files & orig_files & img5_files
186
+
187
+ if not common_files:
188
+ print("No common image files found in all 5 folders.")
189
  return []
190
 
191
+ # Optional: simple debug info
192
+ print(f"Found {len(common_files)} common images.")
 
 
 
 
193
 
194
  samples: List[Sample] = []
195
+ for base_filename in sorted(common_files):
 
 
 
 
196
  sample_id = os.path.splitext(base_filename)[0]
197
+
198
  paths = {
199
+ "masked": os.path.join(gt_masked_dir, base_filename),
200
  "unmasked": os.path.join(gt_unmasked_dir, base_filename),
201
+ "sr": os.path.join(sr_dir, base_filename),
202
  "original": os.path.join(original_dir, base_filename),
203
+ "img5": os.path.join(image_5_dir, base_filename),
204
  }
205
 
206
+ # If STRICT_ENFORCEMENT is True, skip if any file missing
207
  if STRICT_ENFORCEMENT:
208
  if not all(os.path.exists(p) for p in paths.values()):
209
  missing = [k for k, v in paths.items() if not os.path.exists(v)]
210
+ print(f"Skipping {base_filename}: missing in folders {missing}")
211
  continue
212
 
 
 
213
  samples.append(
214
  Sample(
215
  sample_id=sample_id,
 
217
  unmasked_gt_path=paths["unmasked"],
218
  sr_path=paths["sr"],
219
  original_path=paths["original"],
220
+ image_5_path=paths["img5"],
221
  )
222
  )
223
 
224
  return samples
225
 
226
+
227
  # ----------------------
228
  # Progress & results I/O
229
  # ----------------------
 
286
  raise gr.Error("Please enter your name and email to begin.")
287
 
288
  ensure_paths()
289
+ samples = load_dataset(GT_MASKED_DIR, GT_UNMASKED_DIR, SR_DIR, ORIGINAL_DIR, IMAGE_5_DIR)
290
 
291
  if not samples:
292
  raise gr.Error("No images found. Please check dataset configuration.")
 
744
  print(f"Error reading from HF: {e}")
745
 
746
  ensure_paths()
747
+ _ = load_dataset(GT_MASKED_DIR, GT_UNMASKED_DIR, SR_DIR, ORIGINAL_DIR, IMAGE_5_DIR)
748
 
749
  print("✅ Launching app.")
750
  demo.queue()