Tan Zi Xu commited on
Commit ·
695209d
1
Parent(s): 52fa700
gdrive integration
Browse files- cloud/storage_gdrive.py +98 -51
- infer.py +7 -18
- requirements.txt +1 -0
- storage.py +23 -12
- ui/sidebar.py +469 -13
cloud/storage_gdrive.py
CHANGED
|
@@ -1,76 +1,123 @@
|
|
| 1 |
# cloud/storage_gdrive.py
|
| 2 |
import io
|
| 3 |
-
import
|
| 4 |
-
import
|
| 5 |
-
from typing import Iterable, Optional
|
| 6 |
|
| 7 |
from google.oauth2 import service_account as gsa
|
|
|
|
| 8 |
from googleapiclient.discovery import build
|
| 9 |
from googleapiclient.http import MediaIoBaseDownload, MediaIoBaseUpload
|
| 10 |
|
| 11 |
from storage import BlobStore, BlobInfo
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
class GDriveStore(BlobStore):
|
| 16 |
-
def __init__(
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
-
if
|
| 22 |
-
creds =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
else:
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
self.svc = build("drive", "v3", credentials=creds, cache_discovery=False)
|
| 29 |
|
| 30 |
-
|
| 31 |
-
q = f"'{self.folder_id}' in parents and trashed=false"
|
| 32 |
-
if extra_q:
|
| 33 |
-
q += f" and {extra_q}"
|
| 34 |
-
return q
|
| 35 |
|
| 36 |
-
|
|
|
|
| 37 |
page_token = None
|
| 38 |
-
name_filter = f"name contains '{prefix}'" if prefix else ""
|
| 39 |
while True:
|
| 40 |
resp = self.svc.files().list(
|
| 41 |
-
q=
|
| 42 |
-
fields="nextPageToken, files(id,
|
| 43 |
pageToken=page_token,
|
| 44 |
corpora="allDrives",
|
| 45 |
includeItemsFromAllDrives=True,
|
| 46 |
supportsAllDrives=True,
|
| 47 |
).execute()
|
| 48 |
for f in resp.get("files", []):
|
| 49 |
-
yield
|
| 50 |
-
f["
|
| 51 |
-
|
|
|
|
| 52 |
modified=f.get("modifiedTime"),
|
| 53 |
-
|
| 54 |
)
|
| 55 |
page_token = resp.get("nextPageToken")
|
| 56 |
if not page_token:
|
| 57 |
break
|
| 58 |
|
| 59 |
-
def
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
def read_bytes(self, key: str) -> bytes:
|
| 71 |
-
fid = self.
|
| 72 |
-
if not fid:
|
| 73 |
-
raise FileNotFoundError(key)
|
| 74 |
req = self.svc.files().get_media(fileId=fid, supportsAllDrives=True)
|
| 75 |
buf = io.BytesIO()
|
| 76 |
downloader = MediaIoBaseDownload(buf, req)
|
|
@@ -80,26 +127,26 @@ class GDriveStore(BlobStore):
|
|
| 80 |
return buf.getvalue()
|
| 81 |
|
| 82 |
def write_bytes(self, key: str, data: bytes, content_type: Optional[str] = None) -> None:
|
|
|
|
|
|
|
| 83 |
media = MediaIoBaseUpload(io.BytesIO(data), mimetype=content_type or "application/octet-stream", resumable=False)
|
| 84 |
-
fid
|
| 85 |
-
|
| 86 |
-
if fid:
|
| 87 |
self.svc.files().update(fileId=fid, media_body=media, supportsAllDrives=True).execute()
|
| 88 |
-
|
|
|
|
| 89 |
self.svc.files().create(body=meta, media_body=media, supportsAllDrives=True, fields="id").execute()
|
| 90 |
|
| 91 |
def head(self, key: str) -> BlobInfo:
|
| 92 |
-
fid = self.
|
| 93 |
-
if not fid:
|
| 94 |
-
raise FileNotFoundError(key)
|
| 95 |
f = self.svc.files().get(
|
| 96 |
fileId=fid,
|
| 97 |
fields="id,name,size,modifiedTime,mimeType",
|
| 98 |
-
supportsAllDrives=True
|
| 99 |
).execute()
|
| 100 |
return BlobInfo(
|
| 101 |
-
f
|
| 102 |
size=int(f.get("size", 0)) if f.get("size") else None,
|
| 103 |
modified=f.get("modifiedTime"),
|
| 104 |
-
is_dir=(f
|
| 105 |
)
|
|
|
|
| 1 |
# cloud/storage_gdrive.py
|
| 2 |
import io
|
| 3 |
+
from typing import Iterable, Optional, List
|
| 4 |
+
from dataclasses import dataclass
|
|
|
|
| 5 |
|
| 6 |
from google.oauth2 import service_account as gsa
|
| 7 |
+
from google.oauth2.credentials import Credentials as UserCreds
|
| 8 |
from googleapiclient.discovery import build
|
| 9 |
from googleapiclient.http import MediaIoBaseDownload, MediaIoBaseUpload
|
| 10 |
|
| 11 |
from storage import BlobStore, BlobInfo
|
| 12 |
|
| 13 |
+
# Read-only scopes (sufficient for listing + downloading images)
|
| 14 |
+
_SCOPES = [
|
| 15 |
+
"https://www.googleapis.com/auth/drive.readonly",
|
| 16 |
+
"https://www.googleapis.com/auth/drive.metadata.readonly",
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
@dataclass
|
| 20 |
+
class _FileRow:
|
| 21 |
+
id: str
|
| 22 |
+
name: str
|
| 23 |
+
size: Optional[int]
|
| 24 |
+
modified: Optional[str]
|
| 25 |
+
mime: str
|
| 26 |
|
| 27 |
class GDriveStore(BlobStore):
|
| 28 |
+
def __init__(
|
| 29 |
+
self,
|
| 30 |
+
folder_id: Optional[str] = None,
|
| 31 |
+
*,
|
| 32 |
+
credentials: Optional[UserCreds] = None, # user OAuth (device flow)
|
| 33 |
+
service_account_json: Optional[str] = None, # SA JSON (optional alt)
|
| 34 |
+
):
|
| 35 |
+
self.root_folder_id = folder_id or "root"
|
| 36 |
|
| 37 |
+
if credentials is not None:
|
| 38 |
+
self.creds = credentials
|
| 39 |
+
elif service_account_json:
|
| 40 |
+
self.creds = gsa.Credentials.from_service_account_file(
|
| 41 |
+
service_account_json, scopes=_SCOPES
|
| 42 |
+
)
|
| 43 |
else:
|
| 44 |
+
raise ValueError(
|
| 45 |
+
"GDriveStore requires user OAuth credentials or a service_account_json path."
|
| 46 |
+
)
|
|
|
|
|
|
|
| 47 |
|
| 48 |
+
self.svc = build("drive", "v3", credentials=self.creds, cache_discovery=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
+
# ---------------- internal helpers ----------------
|
| 51 |
+
def _iter_files(self, q: str) -> Iterable[_FileRow]:
|
| 52 |
page_token = None
|
|
|
|
| 53 |
while True:
|
| 54 |
resp = self.svc.files().list(
|
| 55 |
+
q=q,
|
| 56 |
+
fields="nextPageToken, files(id,name,size,modifiedTime,mimeType)",
|
| 57 |
pageToken=page_token,
|
| 58 |
corpora="allDrives",
|
| 59 |
includeItemsFromAllDrives=True,
|
| 60 |
supportsAllDrives=True,
|
| 61 |
).execute()
|
| 62 |
for f in resp.get("files", []):
|
| 63 |
+
yield _FileRow(
|
| 64 |
+
id=f["id"],
|
| 65 |
+
name=f["name"],
|
| 66 |
+
size=int(f["size"]) if f.get("size") else None,
|
| 67 |
modified=f.get("modifiedTime"),
|
| 68 |
+
mime=f.get("mimeType", ""),
|
| 69 |
)
|
| 70 |
page_token = resp.get("nextPageToken")
|
| 71 |
if not page_token:
|
| 72 |
break
|
| 73 |
|
| 74 |
+
def _children_of(self, folder_id: str, name_filter: str = "") -> Iterable[_FileRow]:
|
| 75 |
+
q = f"'{folder_id}' in parents and trashed=false"
|
| 76 |
+
if name_filter:
|
| 77 |
+
# Drive query uses single quotes; no special characters expected in 'contains'
|
| 78 |
+
q += f" and name contains '{name_filter}'"
|
| 79 |
+
yield from self._iter_files(q)
|
| 80 |
+
|
| 81 |
+
def _resolve_path(self, prefix_path: str) -> str:
|
| 82 |
+
"""
|
| 83 |
+
Resolve a 'folder1/folder2' path starting from self.root_folder_id and return the final folder ID.
|
| 84 |
+
If any component is missing, returns self.root_folder_id (so listing won't crash).
|
| 85 |
+
"""
|
| 86 |
+
if not prefix_path or prefix_path.strip() == "":
|
| 87 |
+
return self.root_folder_id
|
| 88 |
+
cur = self.root_folder_id
|
| 89 |
+
parts: List[str] = [p for p in prefix_path.split("/") if p]
|
| 90 |
+
for seg in parts:
|
| 91 |
+
# find a child folder with this name
|
| 92 |
+
q = f"'{cur}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder' and name='{seg}'"
|
| 93 |
+
match = next(self._iter_files(q), None)
|
| 94 |
+
if not match:
|
| 95 |
+
# folder segment not found; fall back to current (best effort)
|
| 96 |
+
return cur
|
| 97 |
+
cur = match.id
|
| 98 |
+
return cur
|
| 99 |
+
|
| 100 |
+
def _extract_id(self, key: str) -> str:
|
| 101 |
+
# Accept "id|name" or bare "id"
|
| 102 |
+
return key.split("|", 1)[0]
|
| 103 |
+
|
| 104 |
+
# ---------------- BlobStore implementation ----------------
|
| 105 |
+
def list(self, prefix: str = "", recursive: bool = False) -> Iterable[BlobInfo]:
|
| 106 |
+
# Treat prefix as a folder path (like S3). Final path component is a folder.
|
| 107 |
+
folder_id = self._resolve_path(prefix)
|
| 108 |
+
for f in self._children_of(folder_id):
|
| 109 |
+
is_dir = (f.mime == "application/vnd.google-apps.folder")
|
| 110 |
+
# Use "id|name" so the UI can display a friendly name but we keep a unique key
|
| 111 |
+
key = f"{f.id}|{f.name}"
|
| 112 |
+
yield BlobInfo(
|
| 113 |
+
key=key,
|
| 114 |
+
size=f.size,
|
| 115 |
+
modified=f.modified,
|
| 116 |
+
is_dir=is_dir,
|
| 117 |
+
)
|
| 118 |
|
| 119 |
def read_bytes(self, key: str) -> bytes:
|
| 120 |
+
fid = self._extract_id(key)
|
|
|
|
|
|
|
| 121 |
req = self.svc.files().get_media(fileId=fid, supportsAllDrives=True)
|
| 122 |
buf = io.BytesIO()
|
| 123 |
downloader = MediaIoBaseDownload(buf, req)
|
|
|
|
| 127 |
return buf.getvalue()
|
| 128 |
|
| 129 |
def write_bytes(self, key: str, data: bytes, content_type: Optional[str] = None) -> None:
|
| 130 |
+
# Writes are not used by your app, but kept for interface parity.
|
| 131 |
+
fid = self._extract_id(key)
|
| 132 |
media = MediaIoBaseUpload(io.BytesIO(data), mimetype=content_type or "application/octet-stream", resumable=False)
|
| 133 |
+
# If fid refers to an existing file, update; otherwise create in root.
|
| 134 |
+
try:
|
|
|
|
| 135 |
self.svc.files().update(fileId=fid, media_body=media, supportsAllDrives=True).execute()
|
| 136 |
+
except Exception:
|
| 137 |
+
meta = {"name": key.split("|",1)[-1], "parents": [self.root_folder_id]}
|
| 138 |
self.svc.files().create(body=meta, media_body=media, supportsAllDrives=True, fields="id").execute()
|
| 139 |
|
| 140 |
def head(self, key: str) -> BlobInfo:
|
| 141 |
+
fid = self._extract_id(key)
|
|
|
|
|
|
|
| 142 |
f = self.svc.files().get(
|
| 143 |
fileId=fid,
|
| 144 |
fields="id,name,size,modifiedTime,mimeType",
|
| 145 |
+
supportsAllDrives=True,
|
| 146 |
).execute()
|
| 147 |
return BlobInfo(
|
| 148 |
+
key=f"{f['id']}|{f['name']}",
|
| 149 |
size=int(f.get("size", 0)) if f.get("size") else None,
|
| 150 |
modified=f.get("modifiedTime"),
|
| 151 |
+
is_dir=(f.get("mimeType") == "application/vnd.google-apps.folder"),
|
| 152 |
)
|
infer.py
CHANGED
|
@@ -13,7 +13,7 @@ from torchcam.methods import GradCAM
|
|
| 13 |
from torchcam.utils import overlay_mask
|
| 14 |
|
| 15 |
# Track last resolution for Diagnostics
|
| 16 |
-
LAST_WEIGHT_SOURCE = None # one of: local | s3 | gdrive | hf_manifest
|
| 17 |
LAST_WEIGHT_DETAIL = "" # key/repo/filename etc
|
| 18 |
LAST_WEIGHT_PATH = "" # final local path used
|
| 19 |
|
|
@@ -241,14 +241,14 @@ def _resolve_ckpt(ckpt_path: str, model_name: str, models_root: str) -> str:
|
|
| 241 |
|
| 242 |
# Order: depends on mode
|
| 243 |
if mode == "off":
|
| 244 |
-
order = ["local", "hf_manifest"
|
| 245 |
elif mode == "prefer_cloud":
|
| 246 |
-
order = ["local", "cloud", "cloud_auto", "hf_manifest"
|
| 247 |
elif mode == "prefer_hf":
|
| 248 |
-
order = ["local", "hf_manifest", "
|
| 249 |
else: # auto (per-allow-list)
|
| 250 |
-
order = ["local", "
|
| 251 |
-
["local", "
|
| 252 |
|
| 253 |
# Try sources in decided order
|
| 254 |
for src in order:
|
|
@@ -287,21 +287,10 @@ def _resolve_ckpt(ckpt_path: str, model_name: str, models_root: str) -> str:
|
|
| 287 |
except Exception as e:
|
| 288 |
_log(f"Manifest read error: {e}")
|
| 289 |
|
| 290 |
-
if src == "hf_env":
|
| 291 |
-
repo = os.getenv("HF_MODEL_REPO")
|
| 292 |
-
if repo:
|
| 293 |
-
fname = os.getenv("HF_MODEL_FILE", os.path.basename(ckpt_path) or "best.pth")
|
| 294 |
-
_log(f"Downloading from env → repo={repo} file={fname}")
|
| 295 |
-
path = hf_hub_download(repo_id=repo, filename=fname, token=os.getenv("HF_TOKEN"))
|
| 296 |
-
LAST_WEIGHT_SOURCE, LAST_WEIGHT_DETAIL, LAST_WEIGHT_PATH = "hf_env", f"{repo}:{fname}", path
|
| 297 |
-
_log(f"Downloaded to cache: {path}")
|
| 298 |
-
return path
|
| 299 |
-
|
| 300 |
# If we got here, nothing worked
|
| 301 |
parts = [f"local={'yes' if avail_local else 'no'}",
|
| 302 |
f"cloud={'yes' if cloud_exists else 'no'} key={cloud_key or '-'} err={cloud_err or '-'}",
|
| 303 |
-
f"hf_manifest={'present' if (Path(models_root)/'manifest.json').exists() else 'absent'}"
|
| 304 |
-
f"hf_env={'yes' if os.getenv('HF_MODEL_REPO') else 'no'}"]
|
| 305 |
raise FileNotFoundError("Could not resolve weights. Probes: " + ", ".join(parts))
|
| 306 |
|
| 307 |
#--- main loading function---
|
|
|
|
| 13 |
from torchcam.utils import overlay_mask
|
| 14 |
|
| 15 |
# Track last resolution for Diagnostics
|
| 16 |
+
LAST_WEIGHT_SOURCE = None # one of: local | s3 | gdrive | hf_manifest
|
| 17 |
LAST_WEIGHT_DETAIL = "" # key/repo/filename etc
|
| 18 |
LAST_WEIGHT_PATH = "" # final local path used
|
| 19 |
|
|
|
|
| 241 |
|
| 242 |
# Order: depends on mode
|
| 243 |
if mode == "off":
|
| 244 |
+
order = ["local", "hf_manifest"]
|
| 245 |
elif mode == "prefer_cloud":
|
| 246 |
+
order = ["local", "cloud", "cloud_auto", "hf_manifest"]
|
| 247 |
elif mode == "prefer_hf":
|
| 248 |
+
order = ["local", "hf_manifest", "cloud", "cloud_auto"]
|
| 249 |
else: # auto (per-allow-list)
|
| 250 |
+
order = ["local", "cloud", "cloud_auto", "hf_manifest"] if _cloud_mode_for(model_name) == "prefer_cloud" \
|
| 251 |
+
else ["local", "hf_manifest", "cloud", "cloud_auto"]
|
| 252 |
|
| 253 |
# Try sources in decided order
|
| 254 |
for src in order:
|
|
|
|
| 287 |
except Exception as e:
|
| 288 |
_log(f"Manifest read error: {e}")
|
| 289 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
# If we got here, nothing worked
|
| 291 |
parts = [f"local={'yes' if avail_local else 'no'}",
|
| 292 |
f"cloud={'yes' if cloud_exists else 'no'} key={cloud_key or '-'} err={cloud_err or '-'}",
|
| 293 |
+
f"hf_manifest={'present' if (Path(models_root)/'manifest.json').exists() else 'absent'}"]
|
|
|
|
| 294 |
raise FileNotFoundError("Could not resolve weights. Probes: " + ", ".join(parts))
|
| 295 |
|
| 296 |
#--- main loading function---
|
requirements.txt
CHANGED
|
@@ -13,4 +13,5 @@ google-api-python-client>=2.137.0
|
|
| 13 |
google-auth>=2.34.0
|
| 14 |
google-auth-httplib2>=0.2.0
|
| 15 |
google-auth-oauthlib>=1.2.0
|
|
|
|
| 16 |
matplotlib>=3.8
|
|
|
|
| 13 |
google-auth>=2.34.0
|
| 14 |
google-auth-httplib2>=0.2.0
|
| 15 |
google-auth-oauthlib>=1.2.0
|
| 16 |
+
requests>=2.31
|
| 17 |
matplotlib>=3.8
|
storage.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
| 1 |
# storage.py
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
-
import io
|
| 5 |
-
import os
|
| 6 |
from dataclasses import dataclass
|
| 7 |
from pathlib import Path
|
| 8 |
from typing import Iterable, Optional
|
|
@@ -42,19 +41,31 @@ class BlobStore:
|
|
| 42 |
|
| 43 |
# ---------- Local FS baseline ----------
|
| 44 |
class LocalStore(BlobStore):
|
| 45 |
-
def __init__(self, root: Path | None = None):
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
if root:
|
| 51 |
-
|
| 52 |
elif os.getenv("APP_DATA_DIR"):
|
| 53 |
-
|
| 54 |
else:
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
def _p(self, key: str) -> Path:
|
| 60 |
return self.root / key
|
|
|
|
| 1 |
# storage.py
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
+
import io, tempfile, os
|
|
|
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from pathlib import Path
|
| 7 |
from typing import Iterable, Optional
|
|
|
|
| 41 |
|
| 42 |
# ---------- Local FS baseline ----------
|
| 43 |
class LocalStore(BlobStore):
|
| 44 |
+
def __init__(self, root: str | Path | None = None):
|
| 45 |
+
"""
|
| 46 |
+
Always land on a writable directory.
|
| 47 |
+
Priority:
|
| 48 |
+
1) explicit root (if provided)
|
| 49 |
+
2) APP_DATA_DIR (if set)
|
| 50 |
+
3) /tmp/label_assistant (always writable on Spaces/containers)
|
| 51 |
+
"""
|
| 52 |
+
# choose base
|
| 53 |
if root:
|
| 54 |
+
base = Path(root)
|
| 55 |
elif os.getenv("APP_DATA_DIR"):
|
| 56 |
+
base = Path(os.getenv("APP_DATA_DIR"))
|
| 57 |
else:
|
| 58 |
+
base = Path(tempfile.gettempdir()) / "label_assistant"
|
| 59 |
+
|
| 60 |
+
# ensure it exists; if it fails, force /tmp
|
| 61 |
+
try:
|
| 62 |
+
base.mkdir(parents=True, exist_ok=True)
|
| 63 |
+
except Exception:
|
| 64 |
+
base = Path(tempfile.gettempdir()) / "label_assistant"
|
| 65 |
+
base.mkdir(parents=True, exist_ok=True)
|
| 66 |
+
|
| 67 |
+
self.root = base
|
| 68 |
+
print(f"[LocalStore] using {self.root}", flush=True)
|
| 69 |
|
| 70 |
def _p(self, key: str) -> Path:
|
| 71 |
return self.root / key
|
ui/sidebar.py
CHANGED
|
@@ -1,10 +1,15 @@
|
|
| 1 |
# ui/sidebar.py
|
| 2 |
-
import os, io, time
|
| 3 |
from pathlib import Path
|
| 4 |
import streamlit as st
|
| 5 |
from PIL import Image
|
| 6 |
-
from storage import
|
| 7 |
from cloud.storage_s3 import S3Store
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
# ------------- small logger -------------
|
| 10 |
def _log_s3(msg: str):
|
|
@@ -12,7 +17,7 @@ def _log_s3(msg: str):
|
|
| 12 |
st.session_state.setdefault("aws_logs", [])
|
| 13 |
st.session_state.aws_logs.append(f"[{ts}] {msg}")
|
| 14 |
|
| 15 |
-
def
|
| 16 |
with st.sidebar.expander("S3 logs", expanded=False):
|
| 17 |
logs = st.session_state.get("aws_logs", [])
|
| 18 |
if not logs:
|
|
@@ -25,22 +30,57 @@ def _logs_panel():
|
|
| 25 |
def build_sidebar_datasource(BASE_DIR: Path):
|
| 26 |
st.sidebar.markdown("---")
|
| 27 |
st.sidebar.subheader("Data source")
|
| 28 |
-
src_choice = st.sidebar.radio(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
if src_choice == "Local":
|
|
|
|
| 31 |
store = LocalStore()
|
| 32 |
-
_logs_panel()
|
| 33 |
return store, src_choice
|
| 34 |
|
| 35 |
if src_choice == "AWS S3":
|
| 36 |
-
store = _s3_connect_ui()
|
| 37 |
return store, src_choice
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
# ------------- S3 connection UI -------------
|
|
@@ -109,9 +149,9 @@ def _s3_connect_ui():
|
|
| 109 |
_remote_browser_ui(store)
|
| 110 |
else:
|
| 111 |
st.info("Enter your S3 details and click **Connect**.")
|
| 112 |
-
|
| 113 |
# When not connected, return harmless LocalStore so the rest of the app doesn't break
|
| 114 |
-
return store if st.session_state.get("aws_store_ok") else LocalStore(
|
| 115 |
|
| 116 |
# ------------- S3 browser -------------
|
| 117 |
def _remote_browser_ui(STORE: S3Store):
|
|
@@ -190,6 +230,422 @@ def _remote_browser_ui(STORE: S3Store):
|
|
| 190 |
st.sidebar.info("No images found for this prefix. Try a different filter or check your Prefix above.")
|
| 191 |
_log_s3(f"No images found for '{list_prefix}'")
|
| 192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
# ------------- Loaded images sidebar -------------
|
| 194 |
def show_loaded_images_sidebar():
|
| 195 |
st.sidebar.markdown("---")
|
|
|
|
| 1 |
# ui/sidebar.py
|
| 2 |
+
import os, io, time, json, requests
|
| 3 |
from pathlib import Path
|
| 4 |
import streamlit as st
|
| 5 |
from PIL import Image
|
| 6 |
+
from storage import LocalStore
|
| 7 |
from cloud.storage_s3 import S3Store
|
| 8 |
+
from google.oauth2.credentials import Credentials
|
| 9 |
+
from cloud.storage_gdrive import GDriveStore
|
| 10 |
+
from googleapiclient.discovery import build
|
| 11 |
+
from google_auth_oauthlib.flow import Flow
|
| 12 |
+
from streamlit.components.v1 import html
|
| 13 |
|
| 14 |
# ------------- small logger -------------
|
| 15 |
def _log_s3(msg: str):
|
|
|
|
| 17 |
st.session_state.setdefault("aws_logs", [])
|
| 18 |
st.session_state.aws_logs.append(f"[{ts}] {msg}")
|
| 19 |
|
| 20 |
+
def _s3_logs_panel():
|
| 21 |
with st.sidebar.expander("S3 logs", expanded=False):
|
| 22 |
logs = st.session_state.get("aws_logs", [])
|
| 23 |
if not logs:
|
|
|
|
| 30 |
def build_sidebar_datasource(BASE_DIR: Path):
|
| 31 |
st.sidebar.markdown("---")
|
| 32 |
st.sidebar.subheader("Data source")
|
| 33 |
+
src_choice = st.sidebar.radio(
|
| 34 |
+
"Load from",
|
| 35 |
+
["Local", "AWS S3", "Google Drive"],
|
| 36 |
+
index=0,
|
| 37 |
+
horizontal=False
|
| 38 |
+
)
|
| 39 |
|
| 40 |
if src_choice == "Local":
|
| 41 |
+
os.environ.setdefault("APP_DATA_DIR", "/tmp/label_assistant") # never /app
|
| 42 |
store = LocalStore()
|
|
|
|
| 43 |
return store, src_choice
|
| 44 |
|
| 45 |
if src_choice == "AWS S3":
|
| 46 |
+
store = _s3_connect_ui() # includes logs + browser
|
| 47 |
return store, src_choice
|
| 48 |
|
| 49 |
+
if src_choice == "Google Drive":
|
| 50 |
+
with st.sidebar:
|
| 51 |
+
# (0) If a redirect URL was staged, perform it immediately (same tab).
|
| 52 |
+
redir = st.session_state.pop("_oauth_redirect", None)
|
| 53 |
+
if redir:
|
| 54 |
+
html(f'<script>window.location.replace("{redir}");</script>', height=0)
|
| 55 |
+
st.stop()
|
| 56 |
+
|
| 57 |
+
# (1) Handle callback first (if we just came back from Google)
|
| 58 |
+
creds = _get_pkce_creds_from_session() or _handle_google_callback_pkce()
|
| 59 |
+
|
| 60 |
+
# (2) Not signed in yet → start PKCE and force a reliable redirect
|
| 61 |
+
if not creds:
|
| 62 |
+
# Generate the auth URL and show it as a clickable link
|
| 63 |
+
url = _start_google_oauth_pkce()
|
| 64 |
+
if url:
|
| 65 |
+
st.link_button("Sign in with Google", url, use_container_width=True)
|
| 66 |
+
st.caption("After granting access, you'll be redirected back here automatically.")
|
| 67 |
+
os.environ.setdefault("APP_DATA_DIR", "/tmp/label_assistant")
|
| 68 |
+
return LocalStore(), src_choice
|
| 69 |
+
|
| 70 |
+
# (3) Show whoami + sign-out once signed in
|
| 71 |
+
email, who = _drive_whoami(creds)
|
| 72 |
+
st.success(f"Signed in as **{who}** · {email}")
|
| 73 |
+
if st.button("Sign out of Google", use_container_width=True):
|
| 74 |
+
for k in ("g_creds", "g_state", "g_client_config", "_oauth_redirect"):
|
| 75 |
+
st.session_state.pop(k, None)
|
| 76 |
+
st.rerun()
|
| 77 |
+
|
| 78 |
+
# (4) Use Drive store once signed in
|
| 79 |
+
store = GDriveStore(credentials=creds)
|
| 80 |
+
_gdrive_browser_ui(store)
|
| 81 |
+
return store, src_choice
|
| 82 |
+
|
| 83 |
+
return LocalStore(), "Local"
|
| 84 |
|
| 85 |
|
| 86 |
# ------------- S3 connection UI -------------
|
|
|
|
| 149 |
_remote_browser_ui(store)
|
| 150 |
else:
|
| 151 |
st.info("Enter your S3 details and click **Connect**.")
|
| 152 |
+
_s3_logs_panel()
|
| 153 |
# When not connected, return harmless LocalStore so the rest of the app doesn't break
|
| 154 |
+
return store if st.session_state.get("aws_store_ok") else LocalStore()
|
| 155 |
|
| 156 |
# ------------- S3 browser -------------
|
| 157 |
def _remote_browser_ui(STORE: S3Store):
|
|
|
|
| 230 |
st.sidebar.info("No images found for this prefix. Try a different filter or check your Prefix above.")
|
| 231 |
_log_s3(f"No images found for '{list_prefix}'")
|
| 232 |
|
| 233 |
+
# ---------- Google Drive connection + browser ----------
|
| 234 |
+
# ---------- Google Drive (PKCE OAuth) ----------
|
| 235 |
+
DRIVE_SCOPES = [
|
| 236 |
+
"https://www.googleapis.com/auth/drive.readonly",
|
| 237 |
+
"https://www.googleapis.com/auth/drive.metadata.readonly",
|
| 238 |
+
"https://www.googleapis.com/auth/drive.file",
|
| 239 |
+
]
|
| 240 |
+
|
| 241 |
+
def _start_google_oauth_pkce():
|
| 242 |
+
client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
|
| 243 |
+
client_secret = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
|
| 244 |
+
redirect_uri = os.getenv("GOOGLE_OAUTH_REDIRECT_URI")
|
| 245 |
+
if not (client_id and client_secret and redirect_uri):
|
| 246 |
+
st.error("Missing GOOGLE_OAUTH_CLIENT_ID / GOOGLE_OAUTH_CLIENT_SECRET / GOOGLE_OAUTH_REDIRECT_URI.")
|
| 247 |
+
return None
|
| 248 |
+
|
| 249 |
+
client_config = {
|
| 250 |
+
"web": {
|
| 251 |
+
"client_id": client_id,
|
| 252 |
+
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
| 253 |
+
"token_uri": "https://oauth2.googleapis.com/token",
|
| 254 |
+
"client_secret": client_secret,
|
| 255 |
+
"redirect_uris": [redirect_uri],
|
| 256 |
+
}
|
| 257 |
+
}
|
| 258 |
+
flow = Flow.from_client_config(client_config, scopes=DRIVE_SCOPES, redirect_uri=redirect_uri)
|
| 259 |
+
# prompt=consent ensures refresh_token on repeated logins during testing
|
| 260 |
+
auth_url, state = flow.authorization_url(access_type="offline", include_granted_scopes="true", prompt="consent")
|
| 261 |
+
st.session_state["g_state"] = state
|
| 262 |
+
st.session_state["g_client_config"] = client_config
|
| 263 |
+
return auth_url
|
| 264 |
+
|
| 265 |
+
def _handle_google_callback_pkce():
|
| 266 |
+
params = st.query_params
|
| 267 |
+
if "code" not in params or "state" not in params:
|
| 268 |
+
return None
|
| 269 |
+
|
| 270 |
+
code = params.get("code")
|
| 271 |
+
state = params.get("state")
|
| 272 |
+
if isinstance(code, list): code = code[0]
|
| 273 |
+
if isinstance(state, list): state = state[0]
|
| 274 |
+
|
| 275 |
+
# Enforce state only if we have one saved (same-session case).
|
| 276 |
+
expected = st.session_state.get("g_state")
|
| 277 |
+
if expected and state != expected:
|
| 278 |
+
st.error("OAuth state mismatch. Please try signing in again.")
|
| 279 |
+
return None
|
| 280 |
+
|
| 281 |
+
# Try to get client config from session; if missing (new session), rebuild from ENV.
|
| 282 |
+
client_config = st.session_state.get("g_client_config")
|
| 283 |
+
if not client_config:
|
| 284 |
+
cid = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
|
| 285 |
+
csec = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
|
| 286 |
+
ruri = os.getenv("GOOGLE_OAUTH_REDIRECT_URI")
|
| 287 |
+
if not (cid and csec and ruri):
|
| 288 |
+
st.error("Missing GOOGLE_OAUTH_CLIENT_ID/SECRET/REDIRECT_URI while handling callback.")
|
| 289 |
+
return None
|
| 290 |
+
client_config = {
|
| 291 |
+
"web": {
|
| 292 |
+
"client_id": cid,
|
| 293 |
+
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
| 294 |
+
"token_uri": "https://oauth2.googleapis.com/token",
|
| 295 |
+
"client_secret": csec,
|
| 296 |
+
"redirect_uris": [ruri],
|
| 297 |
+
}
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
redirect_uri = client_config["web"]["redirect_uris"][0]
|
| 301 |
+
flow = Flow.from_client_config(client_config, scopes=DRIVE_SCOPES, redirect_uri=redirect_uri)
|
| 302 |
+
try:
|
| 303 |
+
flow.fetch_token(code=code)
|
| 304 |
+
except Exception as e:
|
| 305 |
+
st.error(f"OAuth token exchange failed: {e}")
|
| 306 |
+
return None
|
| 307 |
+
|
| 308 |
+
creds = flow.credentials
|
| 309 |
+
st.session_state["g_creds"] = {
|
| 310 |
+
"token": creds.token,
|
| 311 |
+
"refresh_token": creds.refresh_token,
|
| 312 |
+
"token_uri": creds.token_uri,
|
| 313 |
+
"client_id": creds.client_id,
|
| 314 |
+
"client_secret": creds.client_secret,
|
| 315 |
+
"scopes": creds.scopes,
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
# Clear query params so we don’t reprocess the callback on every rerun
|
| 319 |
+
try:
|
| 320 |
+
st.query_params.clear()
|
| 321 |
+
except Exception:
|
| 322 |
+
pass
|
| 323 |
+
return creds
|
| 324 |
+
|
| 325 |
+
def _get_pkce_creds_from_session():
|
| 326 |
+
data = st.session_state.get("g_creds")
|
| 327 |
+
return Credentials(**data) if data else None
|
| 328 |
+
|
| 329 |
+
def _gdrive_display_name(k: str) -> str:
|
| 330 |
+
return k.split("|", 1)[1] if "|" in k else k
|
| 331 |
+
|
| 332 |
+
def _gdev_start():
|
| 333 |
+
cid = os.getenv("GOOGLE_OAUTH_CLIENT_ID", "").strip()
|
| 334 |
+
if not cid:
|
| 335 |
+
st.error("Missing GOOGLE_OAUTH_CLIENT_ID env. Set it in your Space/local .env.")
|
| 336 |
+
return
|
| 337 |
+
if not cid.endswith(".apps.googleusercontent.com"):
|
| 338 |
+
st.warning("GOOGLE_OAUTH_CLIENT_ID doesn't look like a valid OAuth client ID (…apps.googleusercontent.com). Double-check the value.")
|
| 339 |
+
|
| 340 |
+
try:
|
| 341 |
+
resp = requests.post(
|
| 342 |
+
"https://oauth2.googleapis.com/device/code",
|
| 343 |
+
data={"client_id": cid, "scope": DRIVE_SCOPES},
|
| 344 |
+
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
| 345 |
+
timeout=15,
|
| 346 |
+
)
|
| 347 |
+
except Exception as e:
|
| 348 |
+
st.error(f"Google device-code start failed: network error: {e}")
|
| 349 |
+
return
|
| 350 |
+
|
| 351 |
+
if resp.status_code != 200:
|
| 352 |
+
# Show the exact JSON/text so you know what's wrong
|
| 353 |
+
try:
|
| 354 |
+
err = resp.json()
|
| 355 |
+
except Exception:
|
| 356 |
+
err = resp.text
|
| 357 |
+
st.error(f"Device-code start failed: {resp.status_code} {err}")
|
| 358 |
+
st.session_state["gdev_error"] = err
|
| 359 |
+
return
|
| 360 |
+
|
| 361 |
+
data = resp.json()
|
| 362 |
+
st.session_state["gdev"] = data
|
| 363 |
+
st.session_state["gdev_started_at"] = time.time()
|
| 364 |
+
st.session_state["gcreds"] = None
|
| 365 |
+
|
| 366 |
+
def _gdev_poll():
|
| 367 |
+
if "gdev" not in st.session_state:
|
| 368 |
+
return
|
| 369 |
+
cid = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
|
| 370 |
+
csec = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET", "")
|
| 371 |
+
dev = st.session_state["gdev"]
|
| 372 |
+
try:
|
| 373 |
+
resp = requests.post(
|
| 374 |
+
"https://oauth2.googleapis.com/token",
|
| 375 |
+
data={
|
| 376 |
+
"client_id": cid,
|
| 377 |
+
"client_secret": csec, # optional for desktop/device, harmless if set
|
| 378 |
+
"device_code": dev["device_code"],
|
| 379 |
+
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
| 380 |
+
},
|
| 381 |
+
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
| 382 |
+
timeout=15,
|
| 383 |
+
)
|
| 384 |
+
except Exception as e:
|
| 385 |
+
st.error(f"Google token poll failed: network error: {e}")
|
| 386 |
+
return
|
| 387 |
+
|
| 388 |
+
if resp.status_code == 400:
|
| 389 |
+
err = (resp.json() or {}).get("error")
|
| 390 |
+
# expected while user hasn’t approved yet
|
| 391 |
+
if err in ("authorization_pending", "slow_down"):
|
| 392 |
+
return
|
| 393 |
+
st.error(f"Google auth error: {err}")
|
| 394 |
+
st.session_state["gdev_error"] = resp.text
|
| 395 |
+
return
|
| 396 |
+
|
| 397 |
+
if resp.status_code != 200:
|
| 398 |
+
st.error(f"Google token poll failed: {resp.status_code} {resp.text}")
|
| 399 |
+
st.session_state["gdev_error"] = resp.text
|
| 400 |
+
return
|
| 401 |
+
|
| 402 |
+
tok = resp.json()
|
| 403 |
+
creds = Credentials(
|
| 404 |
+
token=tok["access_token"],
|
| 405 |
+
refresh_token=tok.get("refresh_token"),
|
| 406 |
+
token_uri="https://oauth2.googleapis.com/token",
|
| 407 |
+
client_id=cid,
|
| 408 |
+
client_secret=csec or None,
|
| 409 |
+
scopes=DRIVE_SCOPES.split(),
|
| 410 |
+
)
|
| 411 |
+
st.session_state["gcreds"] = {
|
| 412 |
+
"token": creds.token,
|
| 413 |
+
"refresh_token": creds.refresh_token,
|
| 414 |
+
"token_uri": creds.token_uri,
|
| 415 |
+
"client_id": creds.client_id,
|
| 416 |
+
"client_secret": creds.client_secret,
|
| 417 |
+
"scopes": creds.scopes,
|
| 418 |
+
}
|
| 419 |
+
|
| 420 |
+
def _drive_whoami(creds):
|
| 421 |
+
"""Return (email, name) for the signed-in account."""
|
| 422 |
+
try:
|
| 423 |
+
svc = build("drive", "v3", credentials=creds, cache_discovery=False)
|
| 424 |
+
me = svc.about().get(fields="user(emailAddress, displayName)").execute()
|
| 425 |
+
u = me.get("user", {}) or {}
|
| 426 |
+
return u.get("emailAddress") or "unknown", u.get("displayName") or "unknown"
|
| 427 |
+
except Exception as e:
|
| 428 |
+
return None, f"whoami failed: {e}"
|
| 429 |
+
|
| 430 |
+
def _drive_debug_list_root(creds, folder_id="root"):
|
| 431 |
+
"""Return count and sample names from the current root to verify listing works."""
|
| 432 |
+
try:
|
| 433 |
+
svc = build("drive", "v3", credentials=creds, cache_discovery=False)
|
| 434 |
+
q = f"'{folder_id or 'root'}' in parents and trashed=false"
|
| 435 |
+
resp = svc.files().list(
|
| 436 |
+
q=q,
|
| 437 |
+
fields="files(id,name,mimeType), nextPageToken",
|
| 438 |
+
corpora="allDrives",
|
| 439 |
+
includeItemsFromAllDrives=True,
|
| 440 |
+
supportsAllDrives=True,
|
| 441 |
+
pageSize=20
|
| 442 |
+
).execute()
|
| 443 |
+
files = resp.get("files", [])
|
| 444 |
+
names = [f["name"] for f in files[:10]]
|
| 445 |
+
return len(files), names
|
| 446 |
+
except Exception as e:
|
| 447 |
+
return None, f"list failed: {e}"
|
| 448 |
+
|
| 449 |
+
def _drive_login_block():
|
| 450 |
+
"""UI block for Google Drive sign-in via Device Code."""
|
| 451 |
+
creds_data = st.session_state.get("gcreds")
|
| 452 |
+
if creds_data:
|
| 453 |
+
st.success("Connected to Google Drive.")
|
| 454 |
+
if st.button("Sign out of Google", use_container_width=True):
|
| 455 |
+
st.session_state.pop("gcreds", None)
|
| 456 |
+
st.session_state.pop("gdev", None)
|
| 457 |
+
st.rerun()
|
| 458 |
+
return Credentials(**creds_data)
|
| 459 |
+
|
| 460 |
+
dev = st.session_state.get("gdev")
|
| 461 |
+
if not dev:
|
| 462 |
+
if st.button("Sign in with Google", use_container_width=True):
|
| 463 |
+
# clear a prior error so it doesn’t linger after a new attempt
|
| 464 |
+
st.session_state.pop("gdev_error", None)
|
| 465 |
+
_gdev_start()
|
| 466 |
+
st.rerun()
|
| 467 |
+
|
| 468 |
+
# ← show the last Google response even if device-code init failed
|
| 469 |
+
ge = st.session_state.get("gdev_error")
|
| 470 |
+
if ge:
|
| 471 |
+
st.sidebar.markdown("**Google auth error (last response):**")
|
| 472 |
+
st.sidebar.code(json.dumps(ge, indent=2) if isinstance(ge, dict) else str(ge))
|
| 473 |
+
|
| 474 |
+
st.caption("You’ll use a short code on accounts.google.com/device to grant Drive read-only access.")
|
| 475 |
+
return None
|
| 476 |
+
|
| 477 |
+
# Show code + link and poll
|
| 478 |
+
st.info(f"Go to **{dev['verification_url']}** and enter this code:\n\n**{dev['user_code']}**")
|
| 479 |
+
col1, col2 = st.columns([1,1])
|
| 480 |
+
with col1:
|
| 481 |
+
if st.button("I’ve approved", use_container_width=True):
|
| 482 |
+
_gdev_poll()
|
| 483 |
+
st.rerun()
|
| 484 |
+
with col2:
|
| 485 |
+
if st.button("Cancel", use_container_width=True):
|
| 486 |
+
st.session_state.pop("gdev", None)
|
| 487 |
+
st.rerun()
|
| 488 |
+
|
| 489 |
+
# Optional: auto-poll every few seconds
|
| 490 |
+
st.caption("Waiting for approval…")
|
| 491 |
+
_gdev_poll()
|
| 492 |
+
|
| 493 |
+
if "gdev_error" in st.session_state:
|
| 494 |
+
st.sidebar.caption("Google response:")
|
| 495 |
+
st.sidebar.code(json.dumps(st.session_state["gdev_error"], indent=2) if isinstance(st.session_state["gdev_error"], dict) else str(st.session_state["gdev_error"]))
|
| 496 |
+
|
| 497 |
+
return None
|
| 498 |
+
|
| 499 |
+
def _gdrive_connect_ui():
|
| 500 |
+
st.sidebar.markdown("**Connect to your Google Drive**")
|
| 501 |
+
with st.sidebar.expander("Drive connection", expanded=True):
|
| 502 |
+
# Two modes: Service Account (recommended on Spaces) OR OAuth Device Code.
|
| 503 |
+
mode = st.radio("Auth mode", ["Service Account (JSON)", "OAuth Device Code"], horizontal=False, key="gd_mode")
|
| 504 |
+
|
| 505 |
+
if mode == "Service Account (JSON)":
|
| 506 |
+
sa_json = st.text_input(
|
| 507 |
+
"Path to service account JSON (mounted/secret)",
|
| 508 |
+
value=st.session_state.get("gd_sa_json", os.getenv("GDRIVE_SERVICE_ACCOUNT_JSON","")),
|
| 509 |
+
key="gd_sa_json"
|
| 510 |
+
)
|
| 511 |
+
folder_id = st.text_input(
|
| 512 |
+
"Root Folder ID (share this folder with your service account email)",
|
| 513 |
+
value=st.session_state.get("gd_folder", os.getenv("GDRIVE_FOLDER_ID","")),
|
| 514 |
+
key="gd_folder"
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
c1, c2 = st.columns([1,1])
|
| 518 |
+
if c1.button("Connect", use_container_width=True):
|
| 519 |
+
try:
|
| 520 |
+
store = GDriveStore(service_account_json=sa_json or None, root_folder_id=folder_id or None)
|
| 521 |
+
# probe
|
| 522 |
+
_ = next(iter(store.list(prefix="", recursive=False)), None)
|
| 523 |
+
st.session_state.gd_store = {
|
| 524 |
+
"ok": True, "err": None, "mode": "sa",
|
| 525 |
+
"sa_json": sa_json, "folder": folder_id
|
| 526 |
+
}
|
| 527 |
+
st.success("Connected to Google Drive (Service Account).")
|
| 528 |
+
st.rerun()
|
| 529 |
+
except Exception as e:
|
| 530 |
+
st.session_state.gd_store = {"ok": False, "err": str(e)}
|
| 531 |
+
st.error(f"Connection failed: {e}")
|
| 532 |
+
|
| 533 |
+
if c2.button("Sign out", use_container_width=True):
|
| 534 |
+
st.session_state.pop("gd_store", None)
|
| 535 |
+
st.rerun()
|
| 536 |
+
|
| 537 |
+
else: # OAuth Device Code (user account)
|
| 538 |
+
client_id = st.text_input("OAuth Client ID", value=st.session_state.get("gd_cid",""), key="gd_cid")
|
| 539 |
+
client_secret = st.text_input("OAuth Client Secret", value=st.session_state.get("gd_csec",""), type="password", key="gd_csec")
|
| 540 |
+
folder_id = st.text_input("Root Folder ID (optional)", value=st.session_state.get("gd_folder",""), key="gd_folder")
|
| 541 |
+
|
| 542 |
+
c1, c2 = st.columns([1,1])
|
| 543 |
+
if c1.button("Connect", use_container_width=True):
|
| 544 |
+
try:
|
| 545 |
+
store = GDriveStore(
|
| 546 |
+
client_id=client_id or None,
|
| 547 |
+
client_secret=client_secret or None,
|
| 548 |
+
root_folder_id=folder_id or None,
|
| 549 |
+
oauth_mode="device" # GDriveStore should implement device-code flow
|
| 550 |
+
)
|
| 551 |
+
_ = next(iter(store.list(prefix="", recursive=False)), None)
|
| 552 |
+
st.session_state.gd_store = {
|
| 553 |
+
"ok": True, "err": None, "mode": "oauth",
|
| 554 |
+
"cid": client_id, "csec": client_secret, "folder": folder_id
|
| 555 |
+
}
|
| 556 |
+
st.success("Connected to Google Drive (OAuth).")
|
| 557 |
+
st.rerun()
|
| 558 |
+
except Exception as e:
|
| 559 |
+
st.session_state.gd_store = {"ok": False, "err": str(e)}
|
| 560 |
+
st.error(f"Connection failed: {e}")
|
| 561 |
+
|
| 562 |
+
if c2.button("Sign out", use_container_width=True):
|
| 563 |
+
st.session_state.pop("gd_store", None)
|
| 564 |
+
st.rerun()
|
| 565 |
+
|
| 566 |
+
# Connected banner + browser
|
| 567 |
+
gd = st.session_state.get("gd_store")
|
| 568 |
+
if gd and gd.get("ok"):
|
| 569 |
+
st.sidebar.success(f"Connected to Drive (mode={gd.get('mode')}).")
|
| 570 |
+
if gd["mode"] == "sa":
|
| 571 |
+
store = GDriveStore(service_account_json=gd.get("sa_json") or None,
|
| 572 |
+
root_folder_id=gd.get("folder") or None)
|
| 573 |
+
else:
|
| 574 |
+
store = GDriveStore(client_id=gd.get("cid") or None,
|
| 575 |
+
client_secret=gd.get("csec") or None,
|
| 576 |
+
root_folder_id=gd.get("folder") or None,
|
| 577 |
+
oauth_mode="device")
|
| 578 |
+
_gdrive_browser_ui(store)
|
| 579 |
+
return store
|
| 580 |
+
|
| 581 |
+
st.info("Enter your Drive credentials and click **Connect**.")
|
| 582 |
+
return LocalStore()
|
| 583 |
+
|
| 584 |
+
def _gdrive_browser_ui(STORE: GDriveStore):
|
| 585 |
+
st.sidebar.markdown("---")
|
| 586 |
+
st.sidebar.subheader("Browse Drive images")
|
| 587 |
+
|
| 588 |
+
list_prefix = st.sidebar.text_input("Subfolder filter (path / ‘folder1/folder2’)", value=st.session_state.get("gd_prefix",""))
|
| 589 |
+
c1, c2 = st.sidebar.columns(2)
|
| 590 |
+
|
| 591 |
+
if c1.button("List", use_container_width=True):
|
| 592 |
+
st.session_state.gd_prefix = list_prefix
|
| 593 |
+
items = []
|
| 594 |
+
try:
|
| 595 |
+
raw = list(STORE.list(prefix=list_prefix, recursive=True))
|
| 596 |
+
for b in raw:
|
| 597 |
+
if b.is_dir:
|
| 598 |
+
continue
|
| 599 |
+
if b.key.lower().endswith((".jpg",".jpeg",".png",".bmp",".webp")):
|
| 600 |
+
items.append(b)
|
| 601 |
+
except Exception as e:
|
| 602 |
+
st.sidebar.error(f"Drive list failed: {e}")
|
| 603 |
+
items = []
|
| 604 |
+
st.session_state.gd_list = items
|
| 605 |
+
|
| 606 |
+
if c2.button("Load all listed", use_container_width=True):
|
| 607 |
+
loaded = 0
|
| 608 |
+
for b in st.session_state.get("gd_list", []):
|
| 609 |
+
try:
|
| 610 |
+
data = STORE.read_bytes(b.key)
|
| 611 |
+
im = Image.open(io.BytesIO(data)).convert("RGB")
|
| 612 |
+
key = f"gdrive://{b.key}"
|
| 613 |
+
st.session_state.images[key] = {
|
| 614 |
+
"key": key, "name": _gdrive_display_name(b.key), "pil": im,
|
| 615 |
+
"boxes": [], "preds": [], "user_labels": [], "actions": [], "canvas_json": None
|
| 616 |
+
}
|
| 617 |
+
if st.session_state.active_key is None:
|
| 618 |
+
st.session_state.active_key = key
|
| 619 |
+
loaded += 1
|
| 620 |
+
except Exception as e:
|
| 621 |
+
st.sidebar.warning(f"Load failed for {b.key}: {e}")
|
| 622 |
+
if loaded:
|
| 623 |
+
st.sidebar.success(f"Loaded {loaded} image(s)")
|
| 624 |
+
st.rerun()
|
| 625 |
+
|
| 626 |
+
remote_items = st.session_state.get("gd_list", [])
|
| 627 |
+
if remote_items:
|
| 628 |
+
st.sidebar.caption(f"{len(remote_items)} image(s) listed under '{list_prefix or '(root)'}'")
|
| 629 |
+
for b in remote_items:
|
| 630 |
+
cols = st.sidebar.columns([0.7, 0.3])
|
| 631 |
+
cols[0].caption(f"🖼️ {_gdrive_display_name(b.key)}")
|
| 632 |
+
if cols[1].button("Load", key=f"gd_load_{b.key}"):
|
| 633 |
+
try:
|
| 634 |
+
data = STORE.read_bytes(b.key)
|
| 635 |
+
im = Image.open(io.BytesIO(data)).convert("RGB")
|
| 636 |
+
key = f"gdrive://{b.key}"
|
| 637 |
+
st.session_state.images[key] = {
|
| 638 |
+
"key": key, "name": _gdrive_display_name(b.key), "pil": im,
|
| 639 |
+
"boxes": [], "preds": [], "user_labels": [], "actions": [], "canvas_json": None
|
| 640 |
+
}
|
| 641 |
+
if st.session_state.active_key is None:
|
| 642 |
+
st.session_state.active_key = key
|
| 643 |
+
st.rerun()
|
| 644 |
+
except Exception as e:
|
| 645 |
+
st.sidebar.error(f"Load failed: {e}")
|
| 646 |
+
else:
|
| 647 |
+
st.sidebar.info("No images found. Try another subfolder.")
|
| 648 |
+
|
| 649 |
# ------------- Loaded images sidebar -------------
|
| 650 |
def show_loaded_images_sidebar():
|
| 651 |
st.sidebar.markdown("---")
|