File size: 4,182 Bytes
0edd56d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import boto3
from botocore.exceptions import ClientError
from botocore.client import Config
from abc import ABC, abstractmethod
from fastapi import UploadFile
from app.core.config import settings
import logging

logger = logging.getLogger(__name__)

class StorageProvider(ABC):
    """Abstract interface implemented by every file storage backend.

    A backend persists an uploaded file via :meth:`save`. The string it
    returns is the identifier the other three methods accept back.
    """

    @abstractmethod
    def save(self, file: UploadFile, filename: str) -> str:
        """Persist ``file`` under ``filename`` and return its stored identifier."""

    @abstractmethod
    def delete(self, filepath: str) -> bool:
        """Remove the stored object ``filepath``; return whether it succeeded."""

    @abstractmethod
    def get_presigned_url(self, filepath: str) -> str:
        """Return a URL (or, for local backends, a path) for reading ``filepath``."""

    @abstractmethod
    def download_to_temp(self, filepath: str) -> str:
        """Return a local filesystem path containing the data of ``filepath``."""

class LocalStorageProvider(StorageProvider):
    """Storage backend that keeps uploads on the local filesystem.

    The identifier returned by :meth:`save` is the file's path
    (``upload_dir`` joined with the filename). The other methods take that
    path back unchanged.
    """

    def __init__(self, upload_dir: str = "uploads"):
        self.upload_dir = upload_dir
        # Create the directory up front so save() can open files in it directly.
        os.makedirs(self.upload_dir, exist_ok=True)

    def save(self, file: UploadFile, filename: str) -> str:
        """Stream ``file`` to ``upload_dir/filename`` and return that path.

        The data is copied in 1 MiB chunks, so a large upload is never held
        in memory all at once. An existing file at the destination is
        overwritten.
        """
        # NOTE(review): ``filename`` is joined without sanitising. A value
        # containing ".." or an absolute path would write outside
        # ``upload_dir``. Confirm that callers pass a trusted or generated name.
        filepath = os.path.join(self.upload_dir, filename)
        # Rewind first, as R2StorageProvider.save does. If the stream was
        # already read (e.g. for validation), we would otherwise copy only the
        # unread remainder, which is often nothing.
        file.file.seek(0)
        with open(filepath, "wb") as buffer:
            while chunk := file.file.read(1024 * 1024):
                buffer.write(chunk)
        return filepath

    def delete(self, filepath: str) -> bool:
        """Delete the file at ``filepath``.

        Returns True if a file was removed and False if nothing was there.
        """
        try:
            os.remove(filepath)
        except FileNotFoundError:
            # EAFP: an exists()-then-remove() check can race with a concurrent
            # delete and raise. Treat "already gone" as "not found".
            return False
        return True

    def get_presigned_url(self, filepath: str) -> str:
        """Return ``filepath`` unchanged.

        Local files need no signing; the path itself is used for streaming.
        """
        return filepath

    def download_to_temp(self, filepath: str) -> str:
        """Return ``filepath`` itself, since the file is already local.

        This is NOT a copy. Unlike the R2 backend, the returned path is the
        stored upload, so a caller that removes the "temp" file afterwards
        would delete the original.
        """
        return filepath

class R2StorageProvider(StorageProvider):
    """Storage backend for R2 object storage through an S3-compatible client.

    Objects are stored under the ``uploads/`` prefix in
    ``settings.R2_BUCKET_NAME``. The identifier returned by :meth:`save`
    is the object key.
    """

    # Lifetime of URLs produced by get_presigned_url (1 hour).
    PRESIGNED_URL_EXPIRY_SECONDS = 3600

    def __init__(self):
        self.s3_client = boto3.client(
            service_name="s3",
            endpoint_url=settings.R2_ENDPOINT_URL,
            aws_access_key_id=settings.R2_ACCESS_KEY_ID,
            aws_secret_access_key=settings.R2_SECRET_ACCESS_KEY,
            region_name="auto",
            # Requests and presigned URLs are signed with SigV4.
            config=Config(signature_version="s3v4"),
        )
        self.bucket_name = settings.R2_BUCKET_NAME

    def save(self, file: UploadFile, filename: str) -> str:
        """Upload ``file`` to the key ``uploads/<filename>`` and return that key."""
        r2_key = f"uploads/{filename}"
        # Rewind in case the stream was partly consumed before reaching us.
        file.file.seek(0)
        extra_args = {}
        # UploadFile.content_type can be None. botocore's parameter validation
        # rejects a None ContentType, so only send it when it is known.
        if file.content_type is not None:
            extra_args["ContentType"] = file.content_type
        self.s3_client.upload_fileobj(
            file.file,
            self.bucket_name,
            r2_key,
            ExtraArgs=extra_args,
        )
        return r2_key

    def delete(self, filepath: str) -> bool:
        """Delete object key ``filepath``.

        Returns False (after logging) on a ``ClientError``, True otherwise.
        NOTE(review): S3-style DeleteObject usually succeeds for missing keys,
        so True may not prove an object existed. Confirm this for R2.
        """
        try:
            self.s3_client.delete_object(Bucket=self.bucket_name, Key=filepath)
        except ClientError as e:
            logger.error("Error deleting file from R2: %s", e)
            return False
        return True

    def get_presigned_url(self, filepath: str) -> str:
        """Return a time-limited GET URL for object key ``filepath``.

        If signing raises ``ClientError``, the error is logged and the bare
        key is returned instead of a URL.
        """
        try:
            return self.s3_client.generate_presigned_url(
                "get_object",
                Params={"Bucket": self.bucket_name, "Key": filepath},
                ExpiresIn=self.PRESIGNED_URL_EXPIRY_SECONDS,
            )
        except ClientError as e:
            logger.error("Error generating presigned url: %s", e)
            return filepath

    def download_to_temp(self, filepath: str) -> str:
        """Download object ``filepath`` into a new temporary file and return its path.

        The temp file is created with ``delete=False``, so the caller owns it
        and must remove it. The file keeps the key's extension (``.tmp`` if
        there is none). If the download fails for any reason, the temp file is
        removed and the exception is re-raised.
        """
        import tempfile

        # Take the extension from the key's last path segment only. Splitting
        # the whole key on "." would turn "uploads/a.b/c" into the suffix
        # ".b/c", which NamedTemporaryFile cannot create.
        _, ext = os.path.splitext(filepath.rsplit("/", 1)[-1])
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=ext or ".tmp")
        # Only the path is needed; download_file opens and writes it itself.
        temp_file.close()
        try:
            self.s3_client.download_file(self.bucket_name, filepath, temp_file.name)
        except Exception as e:
            # Clean up on every failure, not only ClientError, so network or
            # transfer errors don't leave orphaned temp files behind.
            logger.error("Error downloading file from R2: %s", e)
            if os.path.exists(temp_file.name):
                os.remove(temp_file.name)
            raise
        return temp_file.name

# Pick the backend once, at import time. Use R2 only when the endpoint, both
# credentials and the bucket name are all set (truthy); otherwise fall back
# to local disk. all() over a generator short-circuits in the same order as
# an explicit `and` chain.
active_storage: StorageProvider
if all(
    getattr(settings, setting_name)
    for setting_name in (
        "R2_ENDPOINT_URL",
        "R2_ACCESS_KEY_ID",
        "R2_SECRET_ACCESS_KEY",
        "R2_BUCKET_NAME",
    )
):
    active_storage = R2StorageProvider()
    logger.info("Using R2StorageProvider for active_storage")
else:
    active_storage = LocalStorageProvider()
    logger.info("Using LocalStorageProvider for active_storage (R2 config missing)")