Tan Zi Xu commited on
Commit
695209d
·
1 Parent(s): 52fa700

gdrive integration

Browse files
Files changed (5) hide show
  1. cloud/storage_gdrive.py +98 -51
  2. infer.py +7 -18
  3. requirements.txt +1 -0
  4. storage.py +23 -12
  5. ui/sidebar.py +469 -13
cloud/storage_gdrive.py CHANGED
@@ -1,76 +1,123 @@
1
  # cloud/storage_gdrive.py
2
  import io
3
- import os
4
- import mimetypes
5
- from typing import Iterable, Optional
6
 
7
  from google.oauth2 import service_account as gsa
 
8
  from googleapiclient.discovery import build
9
  from googleapiclient.http import MediaIoBaseDownload, MediaIoBaseUpload
10
 
11
  from storage import BlobStore, BlobInfo
12
 
13
- _SCOPES = ["https://www.googleapis.com/auth/drive"]
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  class GDriveStore(BlobStore):
16
- def __init__(self, folder_id: str, creds_json_path: str = "", service_account_json: str = ""):
17
- if not folder_id:
18
- raise ValueError("GDriveStore requires folder_id")
19
- self.folder_id = folder_id
 
 
 
 
20
 
21
- if service_account_json:
22
- creds = gsa.Credentials.from_service_account_file(service_account_json, scopes=_SCOPES)
 
 
 
 
23
  else:
24
- if not creds_json_path:
25
- raise ValueError("Provide GDRIVE_CREDENTIALS_JSON (OAuth) or GDRIVE_SERVICE_ACCOUNT_JSON")
26
- creds = gsa.Credentials.from_service_account_file(creds_json_path, scopes=_SCOPES)
27
-
28
- self.svc = build("drive", "v3", credentials=creds, cache_discovery=False)
29
 
30
- def _query(self, extra_q: str = ""):
31
- q = f"'{self.folder_id}' in parents and trashed=false"
32
- if extra_q:
33
- q += f" and {extra_q}"
34
- return q
35
 
36
- def list(self, prefix: str = "", recursive: bool = False) -> Iterable[BlobInfo]:
 
37
  page_token = None
38
- name_filter = f"name contains '{prefix}'" if prefix else ""
39
  while True:
40
  resp = self.svc.files().list(
41
- q=self._query(name_filter),
42
- fields="nextPageToken, files(id, name, size, modifiedTime, mimeType)",
43
  pageToken=page_token,
44
  corpora="allDrives",
45
  includeItemsFromAllDrives=True,
46
  supportsAllDrives=True,
47
  ).execute()
48
  for f in resp.get("files", []):
49
- yield BlobInfo(
50
- f["name"],
51
- size=int(f.get("size", 0)) if f.get("size") else None,
 
52
  modified=f.get("modifiedTime"),
53
- is_dir=(f["mimeType"] == "application/vnd.google-apps.folder"),
54
  )
55
  page_token = resp.get("nextPageToken")
56
  if not page_token:
57
  break
58
 
59
- def _find_id(self, name: str) -> Optional[str]:
60
- resp = self.svc.files().list(
61
- q=self._query(f"name='{name}'"),
62
- fields="files(id, name)",
63
- corpora="allDrives",
64
- includeItemsFromAllDrives=True,
65
- supportsAllDrives=True,
66
- ).execute()
67
- files = resp.get("files", [])
68
- return files[0]["id"] if files else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  def read_bytes(self, key: str) -> bytes:
71
- fid = self._find_id(key)
72
- if not fid:
73
- raise FileNotFoundError(key)
74
  req = self.svc.files().get_media(fileId=fid, supportsAllDrives=True)
75
  buf = io.BytesIO()
76
  downloader = MediaIoBaseDownload(buf, req)
@@ -80,26 +127,26 @@ class GDriveStore(BlobStore):
80
  return buf.getvalue()
81
 
82
  def write_bytes(self, key: str, data: bytes, content_type: Optional[str] = None) -> None:
 
 
83
  media = MediaIoBaseUpload(io.BytesIO(data), mimetype=content_type or "application/octet-stream", resumable=False)
84
- fid = self._find_id(key)
85
- meta = {"name": key, "parents": [self.folder_id]}
86
- if fid:
87
  self.svc.files().update(fileId=fid, media_body=media, supportsAllDrives=True).execute()
88
- else:
 
89
  self.svc.files().create(body=meta, media_body=media, supportsAllDrives=True, fields="id").execute()
90
 
91
  def head(self, key: str) -> BlobInfo:
92
- fid = self._find_id(key)
93
- if not fid:
94
- raise FileNotFoundError(key)
95
  f = self.svc.files().get(
96
  fileId=fid,
97
  fields="id,name,size,modifiedTime,mimeType",
98
- supportsAllDrives=True
99
  ).execute()
100
  return BlobInfo(
101
- f["name"],
102
  size=int(f.get("size", 0)) if f.get("size") else None,
103
  modified=f.get("modifiedTime"),
104
- is_dir=(f["mimeType"] == "application/vnd.google-apps.folder"),
105
  )
 
1
  # cloud/storage_gdrive.py
2
  import io
3
+ from typing import Iterable, Optional, List
4
+ from dataclasses import dataclass
 
5
 
6
  from google.oauth2 import service_account as gsa
7
+ from google.oauth2.credentials import Credentials as UserCreds
8
  from googleapiclient.discovery import build
9
  from googleapiclient.http import MediaIoBaseDownload, MediaIoBaseUpload
10
 
11
  from storage import BlobStore, BlobInfo
12
 
13
+ # Read-only scopes (sufficient for listing + downloading images)
14
+ _SCOPES = [
15
+ "https://www.googleapis.com/auth/drive.readonly",
16
+ "https://www.googleapis.com/auth/drive.metadata.readonly",
17
+ ]
18
+
19
+ @dataclass
20
+ class _FileRow:
21
+ id: str
22
+ name: str
23
+ size: Optional[int]
24
+ modified: Optional[str]
25
+ mime: str
26
 
27
  class GDriveStore(BlobStore):
28
+ def __init__(
29
+ self,
30
+ folder_id: Optional[str] = None,
31
+ *,
32
+ credentials: Optional[UserCreds] = None, # user OAuth (device flow)
33
+ service_account_json: Optional[str] = None, # SA JSON (optional alt)
34
+ ):
35
+ self.root_folder_id = folder_id or "root"
36
 
37
+ if credentials is not None:
38
+ self.creds = credentials
39
+ elif service_account_json:
40
+ self.creds = gsa.Credentials.from_service_account_file(
41
+ service_account_json, scopes=_SCOPES
42
+ )
43
  else:
44
+ raise ValueError(
45
+ "GDriveStore requires user OAuth credentials or a service_account_json path."
46
+ )
 
 
47
 
48
+ self.svc = build("drive", "v3", credentials=self.creds, cache_discovery=False)
 
 
 
 
49
 
50
+ # ---------------- internal helpers ----------------
51
+ def _iter_files(self, q: str) -> Iterable[_FileRow]:
52
  page_token = None
 
53
  while True:
54
  resp = self.svc.files().list(
55
+ q=q,
56
+ fields="nextPageToken, files(id,name,size,modifiedTime,mimeType)",
57
  pageToken=page_token,
58
  corpora="allDrives",
59
  includeItemsFromAllDrives=True,
60
  supportsAllDrives=True,
61
  ).execute()
62
  for f in resp.get("files", []):
63
+ yield _FileRow(
64
+ id=f["id"],
65
+ name=f["name"],
66
+ size=int(f["size"]) if f.get("size") else None,
67
  modified=f.get("modifiedTime"),
68
+ mime=f.get("mimeType", ""),
69
  )
70
  page_token = resp.get("nextPageToken")
71
  if not page_token:
72
  break
73
 
74
+ def _children_of(self, folder_id: str, name_filter: str = "") -> Iterable[_FileRow]:
75
+ q = f"'{folder_id}' in parents and trashed=false"
76
+ if name_filter:
77
+ # Drive query uses single quotes; no special characters expected in 'contains'
78
+ q += f" and name contains '{name_filter}'"
79
+ yield from self._iter_files(q)
80
+
81
+ def _resolve_path(self, prefix_path: str) -> str:
82
+ """
83
+ Resolve a 'folder1/folder2' path starting from self.root_folder_id and return the final folder ID.
84
+ If any component is missing, returns self.root_folder_id (so listing won't crash).
85
+ """
86
+ if not prefix_path or prefix_path.strip() == "":
87
+ return self.root_folder_id
88
+ cur = self.root_folder_id
89
+ parts: List[str] = [p for p in prefix_path.split("/") if p]
90
+ for seg in parts:
91
+ # find a child folder with this name
92
+ q = f"'{cur}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder' and name='{seg}'"
93
+ match = next(self._iter_files(q), None)
94
+ if not match:
95
+ # folder segment not found; fall back to current (best effort)
96
+ return cur
97
+ cur = match.id
98
+ return cur
99
+
100
+ def _extract_id(self, key: str) -> str:
101
+ # Accept "id|name" or bare "id"
102
+ return key.split("|", 1)[0]
103
+
104
+ # ---------------- BlobStore implementation ----------------
105
+ def list(self, prefix: str = "", recursive: bool = False) -> Iterable[BlobInfo]:
106
+ # Treat prefix as a folder path (like S3). Final path component is a folder.
107
+ folder_id = self._resolve_path(prefix)
108
+ for f in self._children_of(folder_id):
109
+ is_dir = (f.mime == "application/vnd.google-apps.folder")
110
+ # Use "id|name" so the UI can display a friendly name but we keep a unique key
111
+ key = f"{f.id}|{f.name}"
112
+ yield BlobInfo(
113
+ key=key,
114
+ size=f.size,
115
+ modified=f.modified,
116
+ is_dir=is_dir,
117
+ )
118
 
119
  def read_bytes(self, key: str) -> bytes:
120
+ fid = self._extract_id(key)
 
 
121
  req = self.svc.files().get_media(fileId=fid, supportsAllDrives=True)
122
  buf = io.BytesIO()
123
  downloader = MediaIoBaseDownload(buf, req)
 
127
  return buf.getvalue()
128
 
129
  def write_bytes(self, key: str, data: bytes, content_type: Optional[str] = None) -> None:
130
+ # Writes are not used by your app, but kept for interface parity.
131
+ fid = self._extract_id(key)
132
  media = MediaIoBaseUpload(io.BytesIO(data), mimetype=content_type or "application/octet-stream", resumable=False)
133
+ # If fid refers to an existing file, update; otherwise create in root.
134
+ try:
 
135
  self.svc.files().update(fileId=fid, media_body=media, supportsAllDrives=True).execute()
136
+ except Exception:
137
+ meta = {"name": key.split("|",1)[-1], "parents": [self.root_folder_id]}
138
  self.svc.files().create(body=meta, media_body=media, supportsAllDrives=True, fields="id").execute()
139
 
140
  def head(self, key: str) -> BlobInfo:
141
+ fid = self._extract_id(key)
 
 
142
  f = self.svc.files().get(
143
  fileId=fid,
144
  fields="id,name,size,modifiedTime,mimeType",
145
+ supportsAllDrives=True,
146
  ).execute()
147
  return BlobInfo(
148
+ key=f"{f['id']}|{f['name']}",
149
  size=int(f.get("size", 0)) if f.get("size") else None,
150
  modified=f.get("modifiedTime"),
151
+ is_dir=(f.get("mimeType") == "application/vnd.google-apps.folder"),
152
  )
infer.py CHANGED
@@ -13,7 +13,7 @@ from torchcam.methods import GradCAM
13
  from torchcam.utils import overlay_mask
14
 
15
  # Track last resolution for Diagnostics
16
- LAST_WEIGHT_SOURCE = None # one of: local | s3 | gdrive | hf_manifest | hf_env
17
  LAST_WEIGHT_DETAIL = "" # key/repo/filename etc
18
  LAST_WEIGHT_PATH = "" # final local path used
19
 
@@ -241,14 +241,14 @@ def _resolve_ckpt(ckpt_path: str, model_name: str, models_root: str) -> str:
241
 
242
  # Order: depends on mode
243
  if mode == "off":
244
- order = ["local", "hf_manifest", "hf_env"]
245
  elif mode == "prefer_cloud":
246
- order = ["local", "cloud", "cloud_auto", "hf_manifest", "hf_env"]
247
  elif mode == "prefer_hf":
248
- order = ["local", "hf_manifest", "hf_env", "cloud", "cloud_auto"]
249
  else: # auto (per-allow-list)
250
- order = ["local", "hf_manifest", "hf_env", "cloud", "cloud_auto"] if (model_name not in (os.getenv("CLOUD_WEIGHTS_ALLOW",""))) else \
251
- ["local", "cloud", "cloud_auto", "hf_manifest", "hf_env"]
252
 
253
  # Try sources in decided order
254
  for src in order:
@@ -287,21 +287,10 @@ def _resolve_ckpt(ckpt_path: str, model_name: str, models_root: str) -> str:
287
  except Exception as e:
288
  _log(f"Manifest read error: {e}")
289
 
290
- if src == "hf_env":
291
- repo = os.getenv("HF_MODEL_REPO")
292
- if repo:
293
- fname = os.getenv("HF_MODEL_FILE", os.path.basename(ckpt_path) or "best.pth")
294
- _log(f"Downloading from env → repo={repo} file={fname}")
295
- path = hf_hub_download(repo_id=repo, filename=fname, token=os.getenv("HF_TOKEN"))
296
- LAST_WEIGHT_SOURCE, LAST_WEIGHT_DETAIL, LAST_WEIGHT_PATH = "hf_env", f"{repo}:{fname}", path
297
- _log(f"Downloaded to cache: {path}")
298
- return path
299
-
300
  # If we got here, nothing worked
301
  parts = [f"local={'yes' if avail_local else 'no'}",
302
  f"cloud={'yes' if cloud_exists else 'no'} key={cloud_key or '-'} err={cloud_err or '-'}",
303
- f"hf_manifest={'present' if (Path(models_root)/'manifest.json').exists() else 'absent'}",
304
- f"hf_env={'yes' if os.getenv('HF_MODEL_REPO') else 'no'}"]
305
  raise FileNotFoundError("Could not resolve weights. Probes: " + ", ".join(parts))
306
 
307
  #--- main loading function---
 
13
  from torchcam.utils import overlay_mask
14
 
15
  # Track last resolution for Diagnostics
16
+ LAST_WEIGHT_SOURCE = None # one of: local | s3 | gdrive | hf_manifest
17
  LAST_WEIGHT_DETAIL = "" # key/repo/filename etc
18
  LAST_WEIGHT_PATH = "" # final local path used
19
 
 
241
 
242
  # Order: depends on mode
243
  if mode == "off":
244
+ order = ["local", "hf_manifest"]
245
  elif mode == "prefer_cloud":
246
+ order = ["local", "cloud", "cloud_auto", "hf_manifest"]
247
  elif mode == "prefer_hf":
248
+ order = ["local", "hf_manifest", "cloud", "cloud_auto"]
249
  else: # auto (per-allow-list)
250
+ order = ["local", "cloud", "cloud_auto", "hf_manifest"] if _cloud_mode_for(model_name) == "prefer_cloud" \
251
+ else ["local", "hf_manifest", "cloud", "cloud_auto"]
252
 
253
  # Try sources in decided order
254
  for src in order:
 
287
  except Exception as e:
288
  _log(f"Manifest read error: {e}")
289
 
 
 
 
 
 
 
 
 
 
 
290
  # If we got here, nothing worked
291
  parts = [f"local={'yes' if avail_local else 'no'}",
292
  f"cloud={'yes' if cloud_exists else 'no'} key={cloud_key or '-'} err={cloud_err or '-'}",
293
+ f"hf_manifest={'present' if (Path(models_root)/'manifest.json').exists() else 'absent'}"]
 
294
  raise FileNotFoundError("Could not resolve weights. Probes: " + ", ".join(parts))
295
 
296
  #--- main loading function---
requirements.txt CHANGED
@@ -13,4 +13,5 @@ google-api-python-client>=2.137.0
13
  google-auth>=2.34.0
14
  google-auth-httplib2>=0.2.0
15
  google-auth-oauthlib>=1.2.0
 
16
  matplotlib>=3.8
 
13
  google-auth>=2.34.0
14
  google-auth-httplib2>=0.2.0
15
  google-auth-oauthlib>=1.2.0
16
+ requests>=2.31
17
  matplotlib>=3.8
storage.py CHANGED
@@ -1,8 +1,7 @@
1
  # storage.py
2
  from __future__ import annotations
3
 
4
- import io
5
- import os
6
  from dataclasses import dataclass
7
  from pathlib import Path
8
  from typing import Iterable, Optional
@@ -42,19 +41,31 @@ class BlobStore:
42
 
43
  # ---------- Local FS baseline ----------
44
  class LocalStore(BlobStore):
45
- def __init__(self, root: Path | None = None):
46
- import os
47
- from pathlib import Path as _P
48
- import tempfile
49
- # 1) APP_DATA_DIR wins if set
 
 
 
 
50
  if root:
51
- self.root = _P(root)
52
  elif os.getenv("APP_DATA_DIR"):
53
- self.root = _P(os.getenv("APP_DATA_DIR"))
54
  else:
55
- # 2) Always-writable on HF Spaces / most containers
56
- self.root = _P(tempfile.gettempdir()) / "label_assistant"
57
- self.root.mkdir(parents=True, exist_ok=True)
 
 
 
 
 
 
 
 
58
 
59
  def _p(self, key: str) -> Path:
60
  return self.root / key
 
1
  # storage.py
2
  from __future__ import annotations
3
 
4
+ import io, tempfile, os
 
5
  from dataclasses import dataclass
6
  from pathlib import Path
7
  from typing import Iterable, Optional
 
41
 
42
  # ---------- Local FS baseline ----------
43
  class LocalStore(BlobStore):
44
+ def __init__(self, root: str | Path | None = None):
45
+ """
46
+ Always land on a writable directory.
47
+ Priority:
48
+ 1) explicit root (if provided)
49
+ 2) APP_DATA_DIR (if set)
50
+ 3) /tmp/label_assistant (always writable on Spaces/containers)
51
+ """
52
+ # choose base
53
  if root:
54
+ base = Path(root)
55
  elif os.getenv("APP_DATA_DIR"):
56
+ base = Path(os.getenv("APP_DATA_DIR"))
57
  else:
58
+ base = Path(tempfile.gettempdir()) / "label_assistant"
59
+
60
+ # ensure it exists; if it fails, force /tmp
61
+ try:
62
+ base.mkdir(parents=True, exist_ok=True)
63
+ except Exception:
64
+ base = Path(tempfile.gettempdir()) / "label_assistant"
65
+ base.mkdir(parents=True, exist_ok=True)
66
+
67
+ self.root = base
68
+ print(f"[LocalStore] using {self.root}", flush=True)
69
 
70
  def _p(self, key: str) -> Path:
71
  return self.root / key
ui/sidebar.py CHANGED
@@ -1,10 +1,15 @@
1
  # ui/sidebar.py
2
- import os, io, time
3
  from pathlib import Path
4
  import streamlit as st
5
  from PIL import Image
6
- from storage import get_store_from_env, LocalStore
7
  from cloud.storage_s3 import S3Store
 
 
 
 
 
8
 
9
  # ------------- small logger -------------
10
  def _log_s3(msg: str):
@@ -12,7 +17,7 @@ def _log_s3(msg: str):
12
  st.session_state.setdefault("aws_logs", [])
13
  st.session_state.aws_logs.append(f"[{ts}] {msg}")
14
 
15
- def _logs_panel():
16
  with st.sidebar.expander("S3 logs", expanded=False):
17
  logs = st.session_state.get("aws_logs", [])
18
  if not logs:
@@ -25,22 +30,57 @@ def _logs_panel():
25
  def build_sidebar_datasource(BASE_DIR: Path):
26
  st.sidebar.markdown("---")
27
  st.sidebar.subheader("Data source")
28
- src_choice = st.sidebar.radio("Load from", ["Local", "AWS S3", "Google Drive"], index=0, horizontal=False)
 
 
 
 
 
29
 
30
  if src_choice == "Local":
 
31
  store = LocalStore()
32
- _logs_panel()
33
  return store, src_choice
34
 
35
  if src_choice == "AWS S3":
36
- store = _s3_connect_ui() # includes connected banner + logs + browser
37
  return store, src_choice
38
 
39
- # Google Drive use your existing get_store_from_env path
40
- os.environ.setdefault("BLOB_BACKEND", "gdrive")
41
- store = get_store_from_env("gdrive")
42
- _logs_panel()
43
- return store, src_choice
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
 
46
  # ------------- S3 connection UI -------------
@@ -109,9 +149,9 @@ def _s3_connect_ui():
109
  _remote_browser_ui(store)
110
  else:
111
  st.info("Enter your S3 details and click **Connect**.")
112
- _logs_panel()
113
  # When not connected, return harmless LocalStore so the rest of the app doesn't break
114
- return store if st.session_state.get("aws_store_ok") else LocalStore(Path.cwd()/ "data")
115
 
116
  # ------------- S3 browser -------------
117
  def _remote_browser_ui(STORE: S3Store):
@@ -190,6 +230,422 @@ def _remote_browser_ui(STORE: S3Store):
190
  st.sidebar.info("No images found for this prefix. Try a different filter or check your Prefix above.")
191
  _log_s3(f"No images found for '{list_prefix}'")
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  # ------------- Loaded images sidebar -------------
194
  def show_loaded_images_sidebar():
195
  st.sidebar.markdown("---")
 
1
  # ui/sidebar.py
2
+ import os, io, time, json, requests
3
  from pathlib import Path
4
  import streamlit as st
5
  from PIL import Image
6
+ from storage import LocalStore
7
  from cloud.storage_s3 import S3Store
8
+ from google.oauth2.credentials import Credentials
9
+ from cloud.storage_gdrive import GDriveStore
10
+ from googleapiclient.discovery import build
11
+ from google_auth_oauthlib.flow import Flow
12
+ from streamlit.components.v1 import html
13
 
14
  # ------------- small logger -------------
15
  def _log_s3(msg: str):
 
17
  st.session_state.setdefault("aws_logs", [])
18
  st.session_state.aws_logs.append(f"[{ts}] {msg}")
19
 
20
+ def _s3_logs_panel():
21
  with st.sidebar.expander("S3 logs", expanded=False):
22
  logs = st.session_state.get("aws_logs", [])
23
  if not logs:
 
30
  def build_sidebar_datasource(BASE_DIR: Path):
31
  st.sidebar.markdown("---")
32
  st.sidebar.subheader("Data source")
33
+ src_choice = st.sidebar.radio(
34
+ "Load from",
35
+ ["Local", "AWS S3", "Google Drive"],
36
+ index=0,
37
+ horizontal=False
38
+ )
39
 
40
  if src_choice == "Local":
41
+ os.environ.setdefault("APP_DATA_DIR", "/tmp/label_assistant") # never /app
42
  store = LocalStore()
 
43
  return store, src_choice
44
 
45
  if src_choice == "AWS S3":
46
+ store = _s3_connect_ui() # includes logs + browser
47
  return store, src_choice
48
 
49
+ if src_choice == "Google Drive":
50
+ with st.sidebar:
51
+ # (0) If a redirect URL was staged, perform it immediately (same tab).
52
+ redir = st.session_state.pop("_oauth_redirect", None)
53
+ if redir:
54
+ html(f'<script>window.location.replace("{redir}");</script>', height=0)
55
+ st.stop()
56
+
57
+ # (1) Handle callback first (if we just came back from Google)
58
+ creds = _get_pkce_creds_from_session() or _handle_google_callback_pkce()
59
+
60
+ # (2) Not signed in yet → start PKCE and force a reliable redirect
61
+ if not creds:
62
+ # Generate the auth URL and show it as a clickable link
63
+ url = _start_google_oauth_pkce()
64
+ if url:
65
+ st.link_button("Sign in with Google", url, use_container_width=True)
66
+ st.caption("After granting access, you'll be redirected back here automatically.")
67
+ os.environ.setdefault("APP_DATA_DIR", "/tmp/label_assistant")
68
+ return LocalStore(), src_choice
69
+
70
+ # (3) Show whoami + sign-out once signed in
71
+ email, who = _drive_whoami(creds)
72
+ st.success(f"Signed in as **{who}** · {email}")
73
+ if st.button("Sign out of Google", use_container_width=True):
74
+ for k in ("g_creds", "g_state", "g_client_config", "_oauth_redirect"):
75
+ st.session_state.pop(k, None)
76
+ st.rerun()
77
+
78
+ # (4) Use Drive store once signed in
79
+ store = GDriveStore(credentials=creds)
80
+ _gdrive_browser_ui(store)
81
+ return store, src_choice
82
+
83
+ return LocalStore(), "Local"
84
 
85
 
86
  # ------------- S3 connection UI -------------
 
149
  _remote_browser_ui(store)
150
  else:
151
  st.info("Enter your S3 details and click **Connect**.")
152
+ _s3_logs_panel()
153
  # When not connected, return harmless LocalStore so the rest of the app doesn't break
154
+ return store if st.session_state.get("aws_store_ok") else LocalStore()
155
 
156
  # ------------- S3 browser -------------
157
  def _remote_browser_ui(STORE: S3Store):
 
230
  st.sidebar.info("No images found for this prefix. Try a different filter or check your Prefix above.")
231
  _log_s3(f"No images found for '{list_prefix}'")
232
 
233
+ # ---------- Google Drive connection + browser ----------
234
+ # ---------- Google Drive (PKCE OAuth) ----------
235
+ DRIVE_SCOPES = [
236
+ "https://www.googleapis.com/auth/drive.readonly",
237
+ "https://www.googleapis.com/auth/drive.metadata.readonly",
238
+ "https://www.googleapis.com/auth/drive.file",
239
+ ]
240
+
241
+ def _start_google_oauth_pkce():
242
+ client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
243
+ client_secret = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
244
+ redirect_uri = os.getenv("GOOGLE_OAUTH_REDIRECT_URI")
245
+ if not (client_id and client_secret and redirect_uri):
246
+ st.error("Missing GOOGLE_OAUTH_CLIENT_ID / GOOGLE_OAUTH_CLIENT_SECRET / GOOGLE_OAUTH_REDIRECT_URI.")
247
+ return None
248
+
249
+ client_config = {
250
+ "web": {
251
+ "client_id": client_id,
252
+ "auth_uri": "https://accounts.google.com/o/oauth2/auth",
253
+ "token_uri": "https://oauth2.googleapis.com/token",
254
+ "client_secret": client_secret,
255
+ "redirect_uris": [redirect_uri],
256
+ }
257
+ }
258
+ flow = Flow.from_client_config(client_config, scopes=DRIVE_SCOPES, redirect_uri=redirect_uri)
259
+ # prompt=consent ensures refresh_token on repeated logins during testing
260
+ auth_url, state = flow.authorization_url(access_type="offline", include_granted_scopes="true", prompt="consent")
261
+ st.session_state["g_state"] = state
262
+ st.session_state["g_client_config"] = client_config
263
+ return auth_url
264
+
265
+ def _handle_google_callback_pkce():
266
+ params = st.query_params
267
+ if "code" not in params or "state" not in params:
268
+ return None
269
+
270
+ code = params.get("code")
271
+ state = params.get("state")
272
+ if isinstance(code, list): code = code[0]
273
+ if isinstance(state, list): state = state[0]
274
+
275
+ # Enforce state only if we have one saved (same-session case).
276
+ expected = st.session_state.get("g_state")
277
+ if expected and state != expected:
278
+ st.error("OAuth state mismatch. Please try signing in again.")
279
+ return None
280
+
281
+ # Try to get client config from session; if missing (new session), rebuild from ENV.
282
+ client_config = st.session_state.get("g_client_config")
283
+ if not client_config:
284
+ cid = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
285
+ csec = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
286
+ ruri = os.getenv("GOOGLE_OAUTH_REDIRECT_URI")
287
+ if not (cid and csec and ruri):
288
+ st.error("Missing GOOGLE_OAUTH_CLIENT_ID/SECRET/REDIRECT_URI while handling callback.")
289
+ return None
290
+ client_config = {
291
+ "web": {
292
+ "client_id": cid,
293
+ "auth_uri": "https://accounts.google.com/o/oauth2/auth",
294
+ "token_uri": "https://oauth2.googleapis.com/token",
295
+ "client_secret": csec,
296
+ "redirect_uris": [ruri],
297
+ }
298
+ }
299
+
300
+ redirect_uri = client_config["web"]["redirect_uris"][0]
301
+ flow = Flow.from_client_config(client_config, scopes=DRIVE_SCOPES, redirect_uri=redirect_uri)
302
+ try:
303
+ flow.fetch_token(code=code)
304
+ except Exception as e:
305
+ st.error(f"OAuth token exchange failed: {e}")
306
+ return None
307
+
308
+ creds = flow.credentials
309
+ st.session_state["g_creds"] = {
310
+ "token": creds.token,
311
+ "refresh_token": creds.refresh_token,
312
+ "token_uri": creds.token_uri,
313
+ "client_id": creds.client_id,
314
+ "client_secret": creds.client_secret,
315
+ "scopes": creds.scopes,
316
+ }
317
+
318
+ # Clear query params so we don’t reprocess the callback on every rerun
319
+ try:
320
+ st.query_params.clear()
321
+ except Exception:
322
+ pass
323
+ return creds
324
+
325
+ def _get_pkce_creds_from_session():
326
+ data = st.session_state.get("g_creds")
327
+ return Credentials(**data) if data else None
328
+
329
+ def _gdrive_display_name(k: str) -> str:
330
+ return k.split("|", 1)[1] if "|" in k else k
331
+
332
+ def _gdev_start():
333
+ cid = os.getenv("GOOGLE_OAUTH_CLIENT_ID", "").strip()
334
+ if not cid:
335
+ st.error("Missing GOOGLE_OAUTH_CLIENT_ID env. Set it in your Space/local .env.")
336
+ return
337
+ if not cid.endswith(".apps.googleusercontent.com"):
338
+ st.warning("GOOGLE_OAUTH_CLIENT_ID doesn't look like a valid OAuth client ID (…apps.googleusercontent.com). Double-check the value.")
339
+
340
+ try:
341
+ resp = requests.post(
342
+ "https://oauth2.googleapis.com/device/code",
343
+ data={"client_id": cid, "scope": DRIVE_SCOPES},
344
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
345
+ timeout=15,
346
+ )
347
+ except Exception as e:
348
+ st.error(f"Google device-code start failed: network error: {e}")
349
+ return
350
+
351
+ if resp.status_code != 200:
352
+ # Show the exact JSON/text so you know what's wrong
353
+ try:
354
+ err = resp.json()
355
+ except Exception:
356
+ err = resp.text
357
+ st.error(f"Device-code start failed: {resp.status_code} {err}")
358
+ st.session_state["gdev_error"] = err
359
+ return
360
+
361
+ data = resp.json()
362
+ st.session_state["gdev"] = data
363
+ st.session_state["gdev_started_at"] = time.time()
364
+ st.session_state["gcreds"] = None
365
+
366
+ def _gdev_poll():
367
+ if "gdev" not in st.session_state:
368
+ return
369
+ cid = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
370
+ csec = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET", "")
371
+ dev = st.session_state["gdev"]
372
+ try:
373
+ resp = requests.post(
374
+ "https://oauth2.googleapis.com/token",
375
+ data={
376
+ "client_id": cid,
377
+ "client_secret": csec, # optional for desktop/device, harmless if set
378
+ "device_code": dev["device_code"],
379
+ "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
380
+ },
381
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
382
+ timeout=15,
383
+ )
384
+ except Exception as e:
385
+ st.error(f"Google token poll failed: network error: {e}")
386
+ return
387
+
388
+ if resp.status_code == 400:
389
+ err = (resp.json() or {}).get("error")
390
+ # expected while user hasn’t approved yet
391
+ if err in ("authorization_pending", "slow_down"):
392
+ return
393
+ st.error(f"Google auth error: {err}")
394
+ st.session_state["gdev_error"] = resp.text
395
+ return
396
+
397
+ if resp.status_code != 200:
398
+ st.error(f"Google token poll failed: {resp.status_code} {resp.text}")
399
+ st.session_state["gdev_error"] = resp.text
400
+ return
401
+
402
+ tok = resp.json()
403
+ creds = Credentials(
404
+ token=tok["access_token"],
405
+ refresh_token=tok.get("refresh_token"),
406
+ token_uri="https://oauth2.googleapis.com/token",
407
+ client_id=cid,
408
+ client_secret=csec or None,
409
+ scopes=DRIVE_SCOPES.split(),
410
+ )
411
+ st.session_state["gcreds"] = {
412
+ "token": creds.token,
413
+ "refresh_token": creds.refresh_token,
414
+ "token_uri": creds.token_uri,
415
+ "client_id": creds.client_id,
416
+ "client_secret": creds.client_secret,
417
+ "scopes": creds.scopes,
418
+ }
419
+
420
+ def _drive_whoami(creds):
421
+ """Return (email, name) for the signed-in account."""
422
+ try:
423
+ svc = build("drive", "v3", credentials=creds, cache_discovery=False)
424
+ me = svc.about().get(fields="user(emailAddress, displayName)").execute()
425
+ u = me.get("user", {}) or {}
426
+ return u.get("emailAddress") or "unknown", u.get("displayName") or "unknown"
427
+ except Exception as e:
428
+ return None, f"whoami failed: {e}"
429
+
430
+ def _drive_debug_list_root(creds, folder_id="root"):
431
+ """Return count and sample names from the current root to verify listing works."""
432
+ try:
433
+ svc = build("drive", "v3", credentials=creds, cache_discovery=False)
434
+ q = f"'{folder_id or 'root'}' in parents and trashed=false"
435
+ resp = svc.files().list(
436
+ q=q,
437
+ fields="files(id,name,mimeType), nextPageToken",
438
+ corpora="allDrives",
439
+ includeItemsFromAllDrives=True,
440
+ supportsAllDrives=True,
441
+ pageSize=20
442
+ ).execute()
443
+ files = resp.get("files", [])
444
+ names = [f["name"] for f in files[:10]]
445
+ return len(files), names
446
+ except Exception as e:
447
+ return None, f"list failed: {e}"
448
+
449
def _drive_login_block():
    """UI block for Google Drive sign-in via Device Code.

    Reads/writes these st.session_state keys:
        gcreds     -- dict of saved credential kwargs; presence means "signed in".
        gdev       -- in-progress device-code handshake state (set by _gdev_start).
        gdev_error -- last raw Google auth response, shown for debugging.

    Returns:
        A Credentials object when already signed in, otherwise None
        (the caller should treat None as "not connected yet").
        NOTE(review): assumes `Credentials` accepts the saved dict as
        keyword args (presumably google.oauth2.credentials.Credentials) —
        confirm against where `gcreds` is populated.
    """
    creds_data = st.session_state.get("gcreds")
    if creds_data:
        # Already authenticated: offer sign-out, which clears both the
        # stored credentials and any stale device-code state.
        st.success("Connected to Google Drive.")
        if st.button("Sign out of Google", use_container_width=True):
            st.session_state.pop("gcreds", None)
            st.session_state.pop("gdev", None)
            st.rerun()
        return Credentials(**creds_data)

    dev = st.session_state.get("gdev")
    if not dev:
        # No device-code flow in progress yet: offer to start one.
        if st.button("Sign in with Google", use_container_width=True):
            # clear a prior error so it doesn’t linger after a new attempt
            st.session_state.pop("gdev_error", None)
            _gdev_start()
            st.rerun()

        # ← show the last Google response even if device-code init failed
        ge = st.session_state.get("gdev_error")
        if ge:
            st.sidebar.markdown("**Google auth error (last response):**")
            st.sidebar.code(json.dumps(ge, indent=2) if isinstance(ge, dict) else str(ge))

        st.caption("You’ll use a short code on accounts.google.com/device to grant Drive read-only access.")
        return None

    # Handshake in progress: show the user code + verification link and poll.
    # NOTE(review): assumes _gdev_start stored 'verification_url' and
    # 'user_code' keys in `gdev` — confirm against the device-code response.
    st.info(f"Go to **{dev['verification_url']}** and enter this code:\n\n**{dev['user_code']}**")
    col1, col2 = st.columns([1,1])
    with col1:
        # Manual poll: user claims approval, so check with Google now.
        if st.button("I’ve approved", use_container_width=True):
            _gdev_poll()
            st.rerun()
    with col2:
        # Cancel abandons the pending device-code handshake.
        if st.button("Cancel", use_container_width=True):
            st.session_state.pop("gdev", None)
            st.rerun()

    # Optional: auto-poll every few seconds
    # (every Streamlit rerun of this block triggers one poll attempt).
    st.caption("Waiting for approval…")
    _gdev_poll()

    # Surface the raw Google response from the poll, for debugging.
    if "gdev_error" in st.session_state:
        st.sidebar.caption("Google response:")
        st.sidebar.code(json.dumps(st.session_state["gdev_error"], indent=2) if isinstance(st.session_state["gdev_error"], dict) else str(st.session_state["gdev_error"]))

    return None
498
+
499
def _gdrive_connect_ui():
    """Sidebar UI to connect to Google Drive and browse it.

    Renders an auth form with two modes (Service Account JSON, or OAuth
    Device Code), probes the connection with one Drive list call, stores
    the working settings in st.session_state.gd_store, and — once
    connected — rebuilds the store on every rerun and shows the browser.

    Returns:
        A connected GDriveStore on success, otherwise a LocalStore
        fallback so the caller always gets a usable BlobStore.
    """
    st.sidebar.markdown("**Connect to your Google Drive**")
    with st.sidebar.expander("Drive connection", expanded=True):
        # Two modes: Service Account (recommended on Spaces) OR OAuth Device Code.
        mode = st.radio("Auth mode", ["Service Account (JSON)", "OAuth Device Code"], horizontal=False, key="gd_mode")

        if mode == "Service Account (JSON)":
            sa_json = st.text_input(
                "Path to service account JSON (mounted/secret)",
                value=st.session_state.get("gd_sa_json", os.getenv("GDRIVE_SERVICE_ACCOUNT_JSON","")),
                key="gd_sa_json"
            )
            folder_id = st.text_input(
                "Root Folder ID (share this folder with your service account email)",
                value=st.session_state.get("gd_folder", os.getenv("GDRIVE_FOLDER_ID","")),
                key="gd_folder"
            )

            c1, c2 = st.columns([1,1])
            if c1.button("Connect", use_container_width=True):
                try:
                    # GDriveStore's constructor is (folder_id, creds_json_path,
                    # service_account_json); there is no `root_folder_id`
                    # keyword, and folder_id is required — the previous call
                    # raised TypeError before ever reaching the API.
                    store = GDriveStore(folder_id=folder_id, service_account_json=sa_json or "")
                    # probe: force one real Drive API call so auth/permission
                    # errors surface here instead of later in the browser
                    _ = next(iter(store.list(prefix="", recursive=False)), None)
                    st.session_state.gd_store = {
                        "ok": True, "err": None, "mode": "sa",
                        "sa_json": sa_json, "folder": folder_id
                    }
                    st.success("Connected to Google Drive (Service Account).")
                    st.rerun()
                except Exception as e:
                    st.session_state.gd_store = {"ok": False, "err": str(e)}
                    st.error(f"Connection failed: {e}")

            if c2.button("Sign out", use_container_width=True):
                st.session_state.pop("gd_store", None)
                st.rerun()

        else:  # OAuth Device Code (user account)
            client_id = st.text_input("OAuth Client ID", value=st.session_state.get("gd_cid",""), key="gd_cid")
            client_secret = st.text_input("OAuth Client Secret", value=st.session_state.get("gd_csec",""), type="password", key="gd_csec")
            folder_id = st.text_input("Root Folder ID (optional)", value=st.session_state.get("gd_folder",""), key="gd_folder")

            c1, c2 = st.columns([1,1])
            if c1.button("Connect", use_container_width=True):
                try:
                    # NOTE(review): GDriveStore's visible __init__ accepts only
                    # (folder_id, creds_json_path, service_account_json) —
                    # no client_id/client_secret/oauth_mode. This call will
                    # raise (and be reported below) until device-code support
                    # is actually implemented in GDriveStore — confirm.
                    store = GDriveStore(
                        client_id=client_id or None,
                        client_secret=client_secret or None,
                        root_folder_id=folder_id or None,
                        oauth_mode="device"  # GDriveStore should implement device-code flow
                    )
                    _ = next(iter(store.list(prefix="", recursive=False)), None)
                    st.session_state.gd_store = {
                        "ok": True, "err": None, "mode": "oauth",
                        "cid": client_id, "csec": client_secret, "folder": folder_id
                    }
                    st.success("Connected to Google Drive (OAuth).")
                    st.rerun()
                except Exception as e:
                    st.session_state.gd_store = {"ok": False, "err": str(e)}
                    st.error(f"Connection failed: {e}")

            if c2.button("Sign out", use_container_width=True):
                st.session_state.pop("gd_store", None)
                st.rerun()

    # Connected banner + browser: rebuild the store from the saved settings
    # on every rerun (Streamlit re-executes this function each time).
    gd = st.session_state.get("gd_store")
    if gd and gd.get("ok"):
        st.sidebar.success(f"Connected to Drive (mode={gd.get('mode')}).")
        if gd["mode"] == "sa":
            # Same signature fix as above: folder_id is the required first arg.
            store = GDriveStore(folder_id=gd.get("folder") or "",
                                service_account_json=gd.get("sa_json") or "")
        else:
            # NOTE(review): see the OAuth branch above — these keywords do not
            # exist on the visible GDriveStore; confirm before relying on them.
            store = GDriveStore(client_id=gd.get("cid") or None,
                                client_secret=gd.get("csec") or None,
                                root_folder_id=gd.get("folder") or None,
                                oauth_mode="device")
        _gdrive_browser_ui(store)
        return store

    st.info("Enter your Drive credentials and click **Connect**.")
    return LocalStore()
583
+
584
# Image extensions the Drive browser will list/load (lower-cased match).
_GD_IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".webp")


def _gd_load_blob(store, blob_key):
    """Fetch one Drive image and register it in st.session_state.images.

    Downloads the blob, decodes it to RGB, stores it under the
    'gdrive://<key>' session key, and makes it the active image if none
    is selected yet. Propagates any download/decode exception so callers
    can decide how to report the failure.
    """
    data = store.read_bytes(blob_key)
    im = Image.open(io.BytesIO(data)).convert("RGB")
    key = f"gdrive://{blob_key}"
    st.session_state.images[key] = {
        "key": key, "name": _gdrive_display_name(blob_key), "pil": im,
        "boxes": [], "preds": [], "user_labels": [], "actions": [], "canvas_json": None
    }
    if st.session_state.active_key is None:
        st.session_state.active_key = key


def _gdrive_browser_ui(STORE: GDriveStore):
    """Sidebar browser for images stored in the connected Google Drive.

    Lets the user list images under a subfolder prefix, then load either
    all listed images or individual ones into the session. Listing
    results are cached in st.session_state.gd_list between reruns.
    """
    st.sidebar.markdown("---")
    st.sidebar.subheader("Browse Drive images")

    list_prefix = st.sidebar.text_input("Subfolder filter (path / ‘folder1/folder2’)", value=st.session_state.get("gd_prefix",""))
    c1, c2 = st.sidebar.columns(2)

    if c1.button("List", use_container_width=True):
        st.session_state.gd_prefix = list_prefix
        items = []
        try:
            # Keep only image files; directories are skipped.
            for b in STORE.list(prefix=list_prefix, recursive=True):
                if b.is_dir:
                    continue
                if b.key.lower().endswith(_GD_IMAGE_EXTS):
                    items.append(b)
        except Exception as e:
            st.sidebar.error(f"Drive list failed: {e}")
            items = []
        st.session_state.gd_list = items

    if c2.button("Load all listed", use_container_width=True):
        loaded = 0
        for b in st.session_state.get("gd_list", []):
            try:
                _gd_load_blob(STORE, b.key)
                loaded += 1
            except Exception as e:
                # Best-effort bulk load: report the failure and continue.
                st.sidebar.warning(f"Load failed for {b.key}: {e}")
        if loaded:
            st.sidebar.success(f"Loaded {loaded} image(s)")
            st.rerun()

    remote_items = st.session_state.get("gd_list", [])
    if remote_items:
        st.sidebar.caption(f"{len(remote_items)} image(s) listed under '{list_prefix or '(root)'}'")
        for b in remote_items:
            cols = st.sidebar.columns([0.7, 0.3])
            cols[0].caption(f"🖼️ {_gdrive_display_name(b.key)}")
            # Per-item load button; the blob key keeps widget keys unique.
            if cols[1].button("Load", key=f"gd_load_{b.key}"):
                try:
                    _gd_load_blob(STORE, b.key)
                    st.rerun()
                except Exception as e:
                    st.sidebar.error(f"Load failed: {e}")
    else:
        st.sidebar.info("No images found. Try another subfolder.")
648
+
649
  # ------------- Loaded images sidebar -------------
650
  def show_loaded_images_sidebar():
651
  st.sidebar.markdown("---")