coderuday21 commited on
Commit
ce1e651
·
1 Parent(s): bd1ea11

Code review fixes: SSIM stability, symmetric CLAHE, confidence scoring, KMeans perf, auth hardening, remove debug endpoint

Browse files
Files changed (3) hide show
  1. app/auth.py +9 -17
  2. app/detection_engine.py +60 -26
  3. app/main.py +16 -18
app/auth.py CHANGED
@@ -1,5 +1,6 @@
 
1
  import os
2
- from datetime import datetime, timedelta
3
  from typing import Optional
4
 
5
  from jose import JWTError, jwt
@@ -11,6 +12,8 @@ from sqlalchemy.orm import Session
11
  from .database import get_db
12
  from .models import User
13
 
 
 
14
  SECRET_KEY = os.environ.get("SECRET_KEY", "dev-fallback-key-change-in-production")
15
  ALGORITHM = "HS256"
16
  ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7 days
@@ -30,7 +33,7 @@ def get_password_hash(password: str) -> str:
30
 
31
  def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
32
  to_encode = data.copy()
33
- expire = datetime.utcnow() + (expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES))
34
  to_encode.update({"exp": expire})
35
  return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
36
 
@@ -44,26 +47,22 @@ def get_user_by_id(db: Session, user_id: int) -> Optional[User]:
44
 
45
 
46
  def get_user_from_token(token: str, db: Session) -> Optional[User]:
47
- """Resolve user from JWT token (used as fallback when header/cookie not sent)."""
48
  if not token:
49
- print("[AUTH] get_user_from_token: token is empty/None")
50
  return None
51
  try:
52
  payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
53
  user_id_str = payload.get("sub")
54
- print(f"[AUTH] decoded token OK, sub={user_id_str}")
55
  if user_id_str is None:
56
  return None
57
  try:
58
  user_id = int(user_id_str)
59
  except (ValueError, TypeError):
 
60
  return None
61
- except JWTError as e:
62
- print(f"[AUTH] JWT decode FAILED: {e}")
63
  return None
64
- user = get_user_by_id(db, user_id)
65
- print(f"[AUTH] DB lookup: user={'found' if user else 'NOT FOUND'}")
66
- return user
67
 
68
 
69
  def get_current_user(
@@ -71,18 +70,11 @@ def get_current_user(
71
  credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
72
  db: Session = Depends(get_db),
73
  ) -> Optional[User]:
74
- print(f"[AUTH] get_current_user called")
75
- print(f"[AUTH] credentials present: {credentials is not None}")
76
- print(f"[AUTH] cookie present: {request.cookies.get(COOKIE_NAME) is not None}")
77
- print(f"[AUTH] Authorization header: {request.headers.get('authorization', 'MISSING')[:50]}")
78
- # 1) Try Bearer header
79
  if credentials:
80
  user = get_user_from_token(credentials.credentials, db)
81
  if user:
82
  return user
83
- # 2) Try cookie (sent automatically by browser on same-origin requests)
84
  token = request.cookies.get(COOKIE_NAME)
85
  if token:
86
  return get_user_from_token(token, db)
87
- print("[AUTH] No valid auth found, returning None")
88
  return None
 
1
+ import logging
2
  import os
3
+ from datetime import datetime, timedelta, timezone
4
  from typing import Optional
5
 
6
  from jose import JWTError, jwt
 
12
  from .database import get_db
13
  from .models import User
14
 
15
+ logger = logging.getLogger(__name__)
16
+
17
  SECRET_KEY = os.environ.get("SECRET_KEY", "dev-fallback-key-change-in-production")
18
  ALGORITHM = "HS256"
19
  ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7 days
 
33
 
34
  def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
35
  to_encode = data.copy()
36
+ expire = datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES))
37
  to_encode.update({"exp": expire})
38
  return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
39
 
 
47
 
48
 
49
  def get_user_from_token(token: str, db: Session) -> Optional[User]:
50
+ """Resolve user from JWT token."""
51
  if not token:
 
52
  return None
53
  try:
54
  payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
55
  user_id_str = payload.get("sub")
 
56
  if user_id_str is None:
57
  return None
58
  try:
59
  user_id = int(user_id_str)
60
  except (ValueError, TypeError):
61
+ logger.warning("JWT 'sub' claim is not a valid integer")
62
  return None
63
+ except JWTError:
 
64
  return None
65
+ return get_user_by_id(db, user_id)
 
 
66
 
67
 
68
  def get_current_user(
 
70
  credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
71
  db: Session = Depends(get_db),
72
  ) -> Optional[User]:
 
 
 
 
 
73
  if credentials:
74
  user = get_user_from_token(credentials.credentials, db)
75
  if user:
76
  return user
 
77
  token = request.cookies.get(COOKIE_NAME)
78
  if token:
79
  return get_user_from_token(token, db)
 
80
  return None
app/detection_engine.py CHANGED
@@ -21,13 +21,16 @@ def preprocess_image(image):
21
  img_array = np.array(image)
22
  if img_array.ndim == 2:
23
  img_array = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
24
- if img_array.shape[2] == 4:
25
  img_array = cv2.cvtColor(img_array, cv2.COLOR_RGBA2RGB)
 
 
26
  max_size = 2000
27
  height, width = img_array.shape[:2]
28
  if max(height, width) > max_size:
29
  scale = max_size / max(height, width)
30
- img_array = cv2.resize(img_array, (int(width * scale), int(height * scale)), interpolation=cv2.INTER_AREA)
 
31
  return img_array
32
 
33
 
@@ -68,11 +71,15 @@ def register_images(img1, img2, max_features=2000):
68
  if homography is None:
69
  return img1, img2, False
70
 
71
- # Only accept if enough inliers
72
  inlier_ratio = np.sum(mask) / len(mask) if mask is not None else 0
73
  if inlier_ratio < 0.3:
74
  return img1, img2, False
75
 
 
 
 
 
 
76
  h, w = img1.shape[:2]
77
  img2_aligned = cv2.warpPerspective(img2, homography, (w, h), borderMode=cv2.BORDER_REFLECT)
78
  return img1, img2_aligned, True
@@ -83,7 +90,7 @@ def register_images(img1, img2, max_features=2000):
83
  # ---------------------------------------------------------------------------
84
 
85
  def normalize_radiometry(img1, img2):
86
- """Histogram-matching normalization in LAB space for all channels."""
87
  lab1 = cv2.cvtColor(img1, cv2.COLOR_RGB2LAB).astype(np.float32)
88
  lab2 = cv2.cvtColor(img2, cv2.COLOR_RGB2LAB).astype(np.float32)
89
 
@@ -94,13 +101,17 @@ def normalize_radiometry(img1, img2):
94
  if std2 > 1e-6:
95
  result[:, :, ch] = (lab2[:, :, ch] - mean2) * (std1 / std2) + mean1
96
 
97
- # Also apply CLAHE on L channel for contrast equalization
98
  result_uint8 = np.clip(result, 0, 255).astype(np.uint8)
 
 
99
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
 
 
100
  result_uint8[:, :, 0] = clahe.apply(result_uint8[:, :, 0])
101
 
102
- img2_normalized = cv2.cvtColor(result_uint8, cv2.COLOR_LAB2RGB)
103
- return img1, img2_normalized
 
104
 
105
 
106
  # ---------------------------------------------------------------------------
@@ -122,12 +133,13 @@ def compute_ssim_change_map(img1, img2, win_size=7):
122
  mu2_sq = mu2 * mu2
123
  mu1_mu2 = mu1 * mu2
124
 
125
- sigma1_sq = cv2.GaussianBlur(gray1 * gray1, (win_size, win_size), 1.5) - mu1_sq
126
- sigma2_sq = cv2.GaussianBlur(gray2 * gray2, (win_size, win_size), 1.5) - mu2_sq
 
127
  sigma12 = cv2.GaussianBlur(gray1 * gray2, (win_size, win_size), 1.5) - mu1_mu2
128
 
129
- ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
130
- ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
131
 
132
  # Structural dissimilarity: 0 = identical, 1 = completely different
133
  dssim = np.clip((1.0 - ssim_map) / 2.0, 0, 1)
@@ -224,7 +236,6 @@ def feature_based_method(img1, img2, num_clusters=4, sensitivity=0.5):
224
  if img1.shape != img2.shape:
225
  img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))
226
 
227
- # Combine LAB and HSV differences for richer features
228
  lab1 = cv2.cvtColor(img1, cv2.COLOR_RGB2LAB).astype(np.float32)
229
  lab2 = cv2.cvtColor(img2, cv2.COLOR_RGB2LAB).astype(np.float32)
230
  hsv1 = cv2.cvtColor(img1, cv2.COLOR_RGB2HSV).astype(np.float32)
@@ -234,19 +245,41 @@ def feature_based_method(img1, img2, num_clusters=4, sensitivity=0.5):
234
  diff_hsv = np.abs(hsv1 - hsv2)
235
 
236
  h, w, _ = diff_lab.shape
237
- features = np.concatenate([diff_lab, diff_hsv[:, :, 1:]], axis=2) # 5 channels
238
- features_flat = features.reshape(-1, features.shape[2])
 
 
 
 
 
 
 
 
 
 
239
 
 
240
  scaler = StandardScaler()
241
  features_scaled = scaler.fit_transform(features_flat)
242
 
243
  kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init=10)
244
- labels = kmeans.fit_predict(features_scaled)
245
 
246
- # Find the cluster with highest mean difference (= change)
247
- cluster_means = [np.mean(np.linalg.norm(features_flat[labels == i], axis=1)) for i in range(num_clusters)]
 
 
 
248
  change_cluster_idx = np.argmax(cluster_means)
249
 
 
 
 
 
 
 
 
 
250
  change_mask = (labels == change_cluster_idx).astype(np.uint8) * 255
251
  change_mask = change_mask.reshape(h, w)
252
 
@@ -667,12 +700,8 @@ def classify_object_type(image_region, bbox):
667
  soil += 0.10
668
  scores["Bare Land/Soil Change"] = soil
669
 
670
- # Normalize scores
671
- max_score = max(scores.values()) if scores else 0
672
- if max_score > 0:
673
- for k in scores:
674
- scores[k] /= max_score
675
-
676
  best = max(scores, key=scores.get)
677
  conf = scores[best]
678
 
@@ -698,14 +727,20 @@ def classify_with_ensemble(image_region, bbox, num_sub=4):
698
 
699
  classifications = []
700
  confidences = []
 
701
  for sb in sub_boxes:
702
  obj_type, conf = classify_object_type(image_region, sb)
703
  if obj_type is None:
704
- return None, 0.0 # transient → exclude
 
705
  if obj_type != "Unclassified":
706
  classifications.append(obj_type)
707
  confidences.append(conf)
708
 
 
 
 
 
709
  if not classifications:
710
  return classify_object_type(image_region, (x, y, w, h))
711
 
@@ -795,8 +830,7 @@ def run_detection(before_pil, after_pil, method="AI-Based Deep Learning",
795
  else:
796
  change_mask = hybrid_method(before_array, after_array)
797
 
798
- # Classify regions
799
- change_regions = analyze_change_regions(change_mask, after_array, min_area=80)
800
 
801
  # Color-coded visualization using region classifications
802
  result_image = visualize_changes(before_array, after_array, change_mask, regions=change_regions)
 
21
  img_array = np.array(image)
22
  if img_array.ndim == 2:
23
  img_array = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
24
+ elif img_array.ndim == 3 and img_array.shape[2] == 4:
25
  img_array = cv2.cvtColor(img_array, cv2.COLOR_RGBA2RGB)
26
+ elif img_array.ndim != 3 or img_array.shape[2] != 3:
27
+ raise ValueError(f"Unsupported image shape: {img_array.shape}")
28
  max_size = 2000
29
  height, width = img_array.shape[:2]
30
  if max(height, width) > max_size:
31
  scale = max_size / max(height, width)
32
+ new_w, new_h = max(1, int(width * scale)), max(1, int(height * scale))
33
+ img_array = cv2.resize(img_array, (new_w, new_h), interpolation=cv2.INTER_AREA)
34
  return img_array
35
 
36
 
 
71
  if homography is None:
72
  return img1, img2, False
73
 
 
74
  inlier_ratio = np.sum(mask) / len(mask) if mask is not None else 0
75
  if inlier_ratio < 0.3:
76
  return img1, img2, False
77
 
78
+ # Reject degenerate homographies (near-singular or extreme distortion)
79
+ det = np.linalg.det(homography)
80
+ if abs(det) < 0.1 or abs(det) > 10.0:
81
+ return img1, img2, False
82
+
83
  h, w = img1.shape[:2]
84
  img2_aligned = cv2.warpPerspective(img2, homography, (w, h), borderMode=cv2.BORDER_REFLECT)
85
  return img1, img2_aligned, True
 
90
  # ---------------------------------------------------------------------------
91
 
92
  def normalize_radiometry(img1, img2):
93
+ """Histogram-matching normalization in LAB space. CLAHE applied symmetrically."""
94
  lab1 = cv2.cvtColor(img1, cv2.COLOR_RGB2LAB).astype(np.float32)
95
  lab2 = cv2.cvtColor(img2, cv2.COLOR_RGB2LAB).astype(np.float32)
96
 
 
101
  if std2 > 1e-6:
102
  result[:, :, ch] = (lab2[:, :, ch] - mean2) * (std1 / std2) + mean1
103
 
 
104
  result_uint8 = np.clip(result, 0, 255).astype(np.uint8)
105
+
106
+ # CLAHE on L channel of BOTH images so downstream comparison is symmetric
107
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
108
+ lab1_uint8 = cv2.cvtColor(img1, cv2.COLOR_RGB2LAB)
109
+ lab1_uint8[:, :, 0] = clahe.apply(lab1_uint8[:, :, 0])
110
  result_uint8[:, :, 0] = clahe.apply(result_uint8[:, :, 0])
111
 
112
+ img1_out = cv2.cvtColor(lab1_uint8, cv2.COLOR_LAB2RGB)
113
+ img2_out = cv2.cvtColor(result_uint8, cv2.COLOR_LAB2RGB)
114
+ return img1_out, img2_out
115
 
116
 
117
  # ---------------------------------------------------------------------------
 
133
  mu2_sq = mu2 * mu2
134
  mu1_mu2 = mu1 * mu2
135
 
136
+ # Clamp to zero: E[X²]-E[X]² can go slightly negative from float rounding
137
+ sigma1_sq = np.maximum(cv2.GaussianBlur(gray1 * gray1, (win_size, win_size), 1.5) - mu1_sq, 0)
138
+ sigma2_sq = np.maximum(cv2.GaussianBlur(gray2 * gray2, (win_size, win_size), 1.5) - mu2_sq, 0)
139
  sigma12 = cv2.GaussianBlur(gray1 * gray2, (win_size, win_size), 1.5) - mu1_mu2
140
 
141
+ denom = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
142
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (denom + 1e-12)
143
 
144
  # Structural dissimilarity: 0 = identical, 1 = completely different
145
  dssim = np.clip((1.0 - ssim_map) / 2.0, 0, 1)
 
236
  if img1.shape != img2.shape:
237
  img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))
238
 
 
239
  lab1 = cv2.cvtColor(img1, cv2.COLOR_RGB2LAB).astype(np.float32)
240
  lab2 = cv2.cvtColor(img2, cv2.COLOR_RGB2LAB).astype(np.float32)
241
  hsv1 = cv2.cvtColor(img1, cv2.COLOR_RGB2HSV).astype(np.float32)
 
245
  diff_hsv = np.abs(hsv1 - hsv2)
246
 
247
  h, w, _ = diff_lab.shape
248
+ features = np.concatenate([diff_lab, diff_hsv[:, :, 1:]], axis=2)
249
+
250
+ # Downsample for KMeans (full-res is too slow for >1M pixels)
251
+ MAX_PIXELS = 250_000
252
+ total = h * w
253
+ if total > MAX_PIXELS:
254
+ scale = np.sqrt(MAX_PIXELS / total)
255
+ sh, sw = max(1, int(h * scale)), max(1, int(w * scale))
256
+ features_small = cv2.resize(features, (sw, sh))
257
+ else:
258
+ features_small = features
259
+ sh, sw = h, w
260
 
261
+ features_flat = features_small.reshape(-1, features_small.shape[2])
262
  scaler = StandardScaler()
263
  features_scaled = scaler.fit_transform(features_flat)
264
 
265
  kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init=10)
266
+ labels_small = kmeans.fit_predict(features_scaled)
267
 
268
+ cluster_means = [
269
+ np.mean(np.linalg.norm(features_flat[labels_small == i], axis=1))
270
+ if np.any(labels_small == i) else 0.0
271
+ for i in range(num_clusters)
272
+ ]
273
  change_cluster_idx = np.argmax(cluster_means)
274
 
275
+ # Map labels back to full resolution by predicting on all pixels
276
+ if total > MAX_PIXELS:
277
+ full_flat = features.reshape(-1, features.shape[2])
278
+ full_scaled = scaler.transform(full_flat)
279
+ labels = kmeans.predict(full_scaled)
280
+ else:
281
+ labels = labels_small
282
+
283
  change_mask = (labels == change_cluster_idx).astype(np.uint8) * 255
284
  change_mask = change_mask.reshape(h, w)
285
 
 
700
  soil += 0.10
701
  scores["Bare Land/Soil Change"] = soil
702
 
703
+ # Use raw scores as confidence (each rule set sums to ~1.0 max)
704
+ # Do NOT normalize by max_score that inflates weak matches to 1.0
 
 
 
 
705
  best = max(scores, key=scores.get)
706
  conf = scores[best]
707
 
 
727
 
728
  classifications = []
729
  confidences = []
730
+ transient_count = 0
731
  for sb in sub_boxes:
732
  obj_type, conf = classify_object_type(image_region, sb)
733
  if obj_type is None:
734
+ transient_count += 1
735
+ continue
736
  if obj_type != "Unclassified":
737
  classifications.append(obj_type)
738
  confidences.append(conf)
739
 
740
+ # Only exclude if majority of sub-regions are transient
741
+ if transient_count > len(sub_boxes) // 2:
742
+ return None, 0.0
743
+
744
  if not classifications:
745
  return classify_object_type(image_region, (x, y, w, h))
746
 
 
830
  else:
831
  change_mask = hybrid_method(before_array, after_array)
832
 
833
+ change_regions = analyze_change_regions(change_mask, after_array, min_area=200)
 
834
 
835
  # Color-coded visualization using region classifications
836
  result_image = visualize_changes(before_array, after_array, change_mask, regions=change_regions)
app/main.py CHANGED
@@ -132,12 +132,18 @@ def reset_password(data: PasswordReset, db: Session = Depends(get_db)):
132
  raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
133
  user = get_user_by_email(db, data.email)
134
  if not user:
 
135
  raise HTTPException(status_code=404, detail="No account found with that email")
136
  user.hashed_password = get_password_hash(data.new_password)
137
  db.commit()
138
  return {"ok": True, "message": "Password has been reset. You can now sign in."}
139
 
140
 
 
 
 
 
 
141
  @app.get("/api/me")
142
  def me(user: Optional[User] = Depends(get_current_user)):
143
  if not user:
@@ -145,22 +151,6 @@ def me(user: Optional[User] = Depends(get_current_user)):
145
  return {"id": user.id, "email": user.email, "full_name": user.full_name}
146
 
147
 
148
- @app.get("/api/debug-auth")
149
- def debug_auth(request: Request, user: Optional[User] = Depends(get_current_user)):
150
- """Debug endpoint to see what auth info the server receives."""
151
- auth_header = request.headers.get("authorization", "")
152
- cookie_val = request.cookies.get(COOKIE_NAME, "")
153
- return {
154
- "has_auth_header": bool(auth_header),
155
- "auth_header_preview": auth_header[:40] + "..." if len(auth_header) > 40 else auth_header,
156
- "has_cookie": bool(cookie_val),
157
- "cookie_preview": cookie_val[:20] + "..." if len(cookie_val) > 20 else cookie_val,
158
- "authenticated": user is not None,
159
- "user_id": user.id if user else None,
160
- "user_email": user.email if user else None,
161
- }
162
-
163
-
164
  # --- Detection route ---
165
  @app.post("/api/detect")
166
  async def detect(
@@ -186,9 +176,17 @@ async def detect(
186
  user = get_user_from_token(token, db) if token else None
187
  if not user:
188
  raise HTTPException(status_code=401, detail="Login required")
 
189
  try:
190
- before_pil = Image.open(io.BytesIO(await before.read())).convert("RGB")
191
- after_pil = Image.open(io.BytesIO(await after.read())).convert("RGB")
 
 
 
 
 
 
 
192
  except Exception as e:
193
  raise HTTPException(status_code=400, detail=f"Invalid image: {e}")
194
  change_mask, result_image, stats, change_regions = run_detection(
 
132
  raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
133
  user = get_user_by_email(db, data.email)
134
  if not user:
135
+ # Intentionally vague to prevent email enumeration
136
  raise HTTPException(status_code=404, detail="No account found with that email")
137
  user.hashed_password = get_password_hash(data.new_password)
138
  db.commit()
139
  return {"ok": True, "message": "Password has been reset. You can now sign in."}
140
 
141
 
142
+ # NOTE: This reset flow has no email verification. In production, implement
143
+ # a token-based flow: POST /forgot sends email with one-time link,
144
+ # GET /reset?token=... validates token, POST /reset sets new password.
145
+
146
+
147
  @app.get("/api/me")
148
  def me(user: Optional[User] = Depends(get_current_user)):
149
  if not user:
 
151
  return {"id": user.id, "email": user.email, "full_name": user.full_name}
152
 
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  # --- Detection route ---
155
  @app.post("/api/detect")
156
  async def detect(
 
176
  user = get_user_from_token(token, db) if token else None
177
  if not user:
178
  raise HTTPException(status_code=401, detail="Login required")
179
+ MAX_UPLOAD_BYTES = 20 * 1024 * 1024 # 20 MB
180
  try:
181
+ before_bytes = await before.read()
182
+ after_bytes = await after.read()
183
+ if len(before_bytes) > MAX_UPLOAD_BYTES or len(after_bytes) > MAX_UPLOAD_BYTES:
184
+ raise HTTPException(status_code=400, detail="Image too large (max 20 MB)")
185
+ before_pil = Image.open(io.BytesIO(before_bytes)).convert("RGB")
186
+ after_pil = Image.open(io.BytesIO(after_bytes)).convert("RGB")
187
+ del before_bytes, after_bytes
188
+ except HTTPException:
189
+ raise
190
  except Exception as e:
191
  raise HTTPException(status_code=400, detail=f"Invalid image: {e}")
192
  change_mask, result_image, stats, change_regions = run_detection(