| |
| import os |
| import boto3 |
| from typing import Iterable, Optional |
| from storage import BlobStore, BlobInfo |
|
|
class S3Store(BlobStore):
    """BlobStore backed by one S3 bucket, optionally scoped under a key prefix.

    Keys given to the public methods are relative to ``prefix``: the store
    joins them with ``prefix`` before calling S3, and ``list`` strips the
    prefix again from the keys it yields.
    """

    def __init__(
        self,
        bucket: str,
        prefix: str = "",
        region: Optional[str] = None,
        endpoint_url: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
    ):
        """Create the S3 client used by this store.

        Each connection setting uses the explicit argument when truthy, else the
        matching ``AWS_*`` environment variable; when both are empty, ``None``
        is passed so boto3 falls back to its own configuration chain.

        Raises:
            ValueError: if *bucket* is empty.
        """
        if not bucket:
            raise ValueError("S3Store requires bucket")
        self.bucket = bucket
        # Stored without surrounding slashes; helpers add the "/" separator.
        self.prefix = (prefix or "").strip("/")

        self.s3 = boto3.client(
            "s3",
            region_name=self._setting(region, "AWS_REGION"),
            endpoint_url=self._setting(endpoint_url, "AWS_ENDPOINT_URL"),
            aws_access_key_id=self._setting(aws_access_key_id, "AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=self._setting(aws_secret_access_key, "AWS_SECRET_ACCESS_KEY"),
            aws_session_token=self._setting(aws_session_token, "AWS_SESSION_TOKEN"),
        )

    @staticmethod
    def _setting(explicit: Optional[str], env_var: str) -> Optional[str]:
        """Return *explicit*, else the value of *env_var*, treating "" as unset."""
        # Without the trailing `or None`, an exported-but-empty variable such as
        # AWS_ENDPOINT_URL="" would reach boto3 as "" rather than "not given".
        return explicit or os.getenv(env_var) or None

    @staticmethod
    def _iso(ts) -> Optional[str]:
        """Return ``ts.isoformat()``, or None when the timestamp is missing."""
        # ts is the LastModified value from the S3 response (a datetime in boto3).
        return ts.isoformat() if ts else None

    def _k(self, key: str) -> str:
        """Map a store-relative object *key* to its full S3 key."""
        # NOTE(review): with a prefix set, trailing slashes on *key* are stripped
        # ("dir/" -> "<prefix>/dir"), while without a prefix "dir/" is kept.
        # Left unchanged because existing stored objects may depend on it.
        return f"{self.prefix}/{key}".strip("/") if self.prefix else key

    def _list_prefix(self, prefix: str) -> str:
        """Return the S3 ``Prefix`` to send when listing store-relative *prefix*.

        Unlike ``_k`` this keeps a trailing slash on *prefix* and always ends a
        bare store prefix with "/", so listing "" or "dir/" cannot also match
        sibling keys like "<prefix>other/..." or "<prefix>/dir2/...".
        """
        if not self.prefix:
            return prefix
        if not prefix.strip("/"):
            return f"{self.prefix}/"
        return f"{self.prefix}/{prefix}"

    def list(self, prefix: str = "", recursive: bool = False) -> Iterable[BlobInfo]:
        """Yield a BlobInfo for each object whose key starts with *prefix*.

        Args:
            prefix: store-relative key prefix to list under.
            recursive: when False, skip objects that have a further "/" below
                *prefix* (i.e. only direct children are yielded).

        Yields:
            BlobInfo with the key relative to the store prefix, its size, and
            its last-modified time as an ISO string (or None).

        Pages are requested lazily via ``ContinuationToken`` as the generator is
        consumed; errors from the S3 client propagate unchanged.
        """
        pfx = self._list_prefix(prefix)
        root = f"{self.prefix}/" if self.prefix else ""
        kwargs = {"Bucket": self.bucket, "Prefix": pfx}
        while True:
            resp = self.s3.list_objects_v2(**kwargs)
            for obj in resp.get("Contents", []):
                key = obj["Key"]
                # Non-recursive filtering is done client-side: S3 still returns
                # nested keys, and any with another segment below pfx is dropped.
                if not recursive and "/" in key[len(pfx):].strip("/"):
                    continue
                short = key[len(root):] if root and key.startswith(root) else key
                yield BlobInfo(
                    short,
                    size=obj.get("Size"),
                    modified=self._iso(obj.get("LastModified")),
                )
            if not resp.get("IsTruncated"):
                break
            kwargs["ContinuationToken"] = resp["NextContinuationToken"]

    def read_bytes(self, key: str) -> bytes:
        """Download the object at *key* and return its full contents."""
        body = self.s3.get_object(Bucket=self.bucket, Key=self._k(key))["Body"]
        try:
            return body.read()
        finally:
            # Release the underlying HTTP connection even if read() fails.
            body.close()

    def write_bytes(self, key: str, data: bytes, content_type: Optional[str] = None) -> None:
        """Upload *data* to *key*, setting ContentType only when one is given."""
        extra = {"ContentType": content_type} if content_type else {}
        self.s3.put_object(Bucket=self.bucket, Key=self._k(key), Body=data, **extra)

    def head(self, key: str) -> BlobInfo:
        """Return size and last-modified metadata for *key* without downloading it."""
        r = self.s3.head_object(Bucket=self.bucket, Key=self._k(key))
        return BlobInfo(
            key,
            size=r.get("ContentLength"),
            modified=self._iso(r.get("LastModified")),
        )
|
|