Tan Zi Xu
S3 Integration and refactoring
c43c055
# cloud/storage_s3.py
import os
import boto3
from typing import Iterable, Optional
from storage import BlobStore, BlobInfo
class S3Store(BlobStore):
    """Blob store backed by one S3 bucket, optionally scoped to a key prefix.

    Keys passed to the public methods are relative to ``prefix``: they are
    joined onto it before each S3 call, and ``list`` strips it again from
    the keys it yields.
    """

    def __init__(
        self,
        bucket: str,
        prefix: str = "",
        region: Optional[str] = None,
        endpoint_url: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
    ):
        """Create the boto3 S3 client for ``bucket``.

        Args:
            bucket: Name of the bucket; must be non-empty.
            prefix: Key prefix every object lives under. Leading and
                trailing slashes are removed; empty means the bucket root.
            region: Region name; falls back to ``AWS_REGION``.
            endpoint_url: Custom endpoint; falls back to ``AWS_ENDPOINT_URL``.
            aws_access_key_id: Explicit key id; falls back to
                ``AWS_ACCESS_KEY_ID``.
            aws_secret_access_key: Explicit secret; falls back to
                ``AWS_SECRET_ACCESS_KEY``.
            aws_session_token: Explicit session token. ``AWS_SESSION_TOKEN``
                is consulted only when neither an explicit key id nor an
                explicit secret was given.

        Raises:
            ValueError: If ``bucket`` is empty.
        """
        if not bucket:
            raise ValueError("S3Store requires bucket")
        self.bucket = bucket
        self.prefix = (prefix or "").strip("/")

        # A session token belongs to one specific key pair. When the caller
        # supplied explicit keys, borrowing a token from the environment would
        # pair it with credentials it was not issued for, so the environment
        # token is used only together with the environment keys.
        if aws_access_key_id or aws_secret_access_key:
            session_token = aws_session_token or None
        else:
            session_token = aws_session_token or os.getenv("AWS_SESSION_TOKEN")

        # Prefer explicit settings (from the UI); fall back to env if not provided
        self.s3 = boto3.client(
            "s3",
            region_name=region or os.getenv("AWS_REGION"),
            endpoint_url=endpoint_url or os.getenv("AWS_ENDPOINT_URL"),
            aws_access_key_id=aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY"),
            aws_session_token=session_token,
        )

    def _k(self, key: str) -> str:
        """Return the full object key for a store-relative ``key``.

        Without a store prefix the key is returned untouched. With one, the
        two are joined by ``/`` and slashes at either end of the result are
        stripped, so a trailing ``/`` on ``key`` is dropped.
        """
        return f"{self.prefix}/{key}".strip("/") if self.prefix else key

    def _list_prefix(self, prefix: str) -> str:
        """Return the S3 ``Prefix`` used to list a store-relative ``prefix``.

        ``_k`` strips trailing slashes, which would make a store prefix of
        ``data`` also match ``database/...``. This restores the ``/`` boundary
        when listing the store root or when the caller's prefix ended in ``/``.
        """
        if not self.prefix:
            return prefix
        full = self._k(prefix)
        if full == self.prefix or prefix.endswith("/"):
            full += "/"
        return full

    @staticmethod
    def _iso(stamp) -> Optional[str]:
        """Return ``stamp.isoformat()``, or ``None`` when ``stamp`` is falsy."""
        return stamp.isoformat() if stamp else None

    def list(self, prefix: str = "", recursive: bool = False) -> Iterable[BlobInfo]:
        """Yield a ``BlobInfo`` for each object under ``prefix``.

        Args:
            prefix: Store-relative key prefix to list; empty lists the
                whole store.
            recursive: When false, objects that sit below a further ``/``
                after the listing prefix are skipped (one-level listing).

        Yields:
            ``BlobInfo`` with the store-relative key, the ``Size`` field and
            the ISO-formatted ``LastModified`` field of each listed object.
        """
        root = f"{self.prefix}/" if self.prefix else ""
        pfx = self._list_prefix(prefix or "")
        kwargs = {"Bucket": self.bucket, "Prefix": pfx}
        while True:
            resp = self.s3.list_objects_v2(**kwargs)
            for obj in resp.get("Contents", []):
                key = obj["Key"]
                # one-level listing if recursive=False
                if not recursive and "/" in key[len(pfx):].strip("/"):
                    continue
                short = key[len(root):] if root and key.startswith(root) else key
                yield BlobInfo(
                    short,
                    size=obj.get("Size"),
                    modified=self._iso(obj.get("LastModified")),
                )
            if not resp.get("IsTruncated"):
                break
            # Results are paged; continue from where this page stopped.
            kwargs["ContinuationToken"] = resp["NextContinuationToken"]

    def read_bytes(self, key: str) -> bytes:
        """Return the full content of the object stored at ``key``."""
        body = self.s3.get_object(Bucket=self.bucket, Key=self._k(key))["Body"]
        try:
            return body.read()
        finally:
            # Close the stream even if the read fails part-way.
            body.close()

    def write_bytes(self, key: str, data: bytes, content_type: Optional[str] = None):
        """Store ``data`` at ``key``.

        ``ContentType`` is sent only when ``content_type`` is truthy.
        """
        extra = {"ContentType": content_type} if content_type else {}
        self.s3.put_object(Bucket=self.bucket, Key=self._k(key), Body=data, **extra)

    def head(self, key: str) -> BlobInfo:
        """Return a ``BlobInfo`` for ``key`` without downloading its content.

        The returned info carries ``key`` exactly as passed in, the
        ``ContentLength`` field and the ISO-formatted ``LastModified`` field
        of the ``head_object`` response.
        """
        r = self.s3.head_object(Bucket=self.bucket, Key=self._k(key))
        return BlobInfo(
            key,
            size=r.get("ContentLength"),
            modified=self._iso(r.get("LastModified")),
        )