File size: 7,205 Bytes
6f79f8b
 
 
 
 
 
a31c678
6f79f8b
 
 
0c4649e
 
6f79f8b
 
 
 
 
 
 
 
 
f34f359
 
6f79f8b
 
 
 
 
 
 
 
 
 
 
 
 
a7c7fdb
6f79f8b
 
 
 
 
 
 
 
0c4649e
6f79f8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16186a1
 
 
 
 
 
 
 
 
6f79f8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import time
import threading
import logging
import boto3
from botocore.exceptions import ClientError, NoCredentialsError, PartialCredentialsError
import os
import re

from dotenv import load_dotenv

from app.exceptions import DBError

# Module-level logger; handlers/formatting are configured by the application.
logger = logging.getLogger(__name__)

# Pull AWS settings from a local .env file into the environment (no-op if absent).
load_dotenv()

# Explicit AWS credentials/region; when unset, boto3's default credential chain
# is used instead (see the fallback in upload_file_to_s3).
AWS_ACCESS_KEY = os.getenv('AWS_ACCESS_KEY')
AWS_SECRET_KEY = os.getenv('AWS_SECRET_KEY')
REGION = os.getenv('AWS_REGION')

def upload_file_to_s3(filename):
    """
    Upload a local user pickle file to the S3 assets bucket, then delete it.

    filename: path to the file; the user id is the basename without its
        extension (e.g. 'users/to_upload/123.pkl' -> '123').
    Returns True on success.
    Raises DBError if the file is missing, credentials are bad/missing, or
        S3 rejects the upload.
    """
    # Extract user_id from the filename
    user_id = os.path.basename(filename).split('.')[0]
    function_name = upload_file_to_s3.__name__
    logger.info(f"Uploading file {filename} to S3", extra={'user_id': user_id, 'endpoint': function_name})
    bucket = 'core-ai-assets'
    try:
        if AWS_ACCESS_KEY and AWS_SECRET_KEY:
            session = boto3.session.Session(
                aws_access_key_id=AWS_ACCESS_KEY,
                aws_secret_access_key=AWS_SECRET_KEY,
                region_name=REGION,
            )
        else:
            # Fall back to boto3's default credential chain (env vars,
            # shared config, instance role).
            session = boto3.session.Session()
        s3_client = session.client('s3')
        with open(filename, "rb") as f:
            ## Upload to Production Folder
            s3_client.upload_fileobj(f, bucket, f'dev/users/{user_id}.pkl')
        logger.info(f"File {filename} uploaded successfully to S3", extra={'user_id': user_id, 'endpoint': function_name})

        # Local copy is no longer needed once it is in S3.
        os.remove(filename)
        return True
    except (FileNotFoundError, ClientError, NoCredentialsError, PartialCredentialsError) as e:
        # BUG FIX: ClientError was imported but never caught, so S3-side
        # failures (access denied, missing bucket, ...) escaped unwrapped.
        logger.error(f"S3 upload failed for {filename}: {e}", extra={'user_id': user_id, 'endpoint': function_name})
        raise DBError(user_id, "S3Error", f"Failed to upload file {filename} to S3", e)

class CustomTTLCache:
    """
    Thread-safe TTL cache whose values are persisted to S3 on eviction.

    Each entry lives for `ttl` seconds from its last access (access refreshes
    the TTL). A background thread removes expired entries every
    `cleanup_interval` seconds and uploads each evicted value via
    `upload_file_to_s3` (through the value's `save_user()` method).
    """

    def __init__(self, ttl=60, cleanup_interval=10):
        """
        ttl: Time to live (in seconds) for each item.
        cleanup_interval: Interval at which background thread checks for expired items.
        """
        self.ttl = ttl
        self.data = {}  # key -> (value, expiration_time)
        self.lock = threading.Lock()
        self.cleanup_interval = cleanup_interval
        self.running = True
        # Event lets close() wake the cleanup thread immediately instead of
        # waiting out a full cleanup_interval sleep.
        self._stop_event = threading.Event()
        self.cleanup_thread = threading.Thread(target=self._cleanup_task, daemon=True)  # for periodic cleanup
        self.cleanup_thread.start()

    def __setitem__(self, key, value):
        """
        For item assignment i.e. cache[key] = value
        """
        self.set(key, value)

    def set(self, key, value):
        """Store value under key with a fresh TTL."""
        # set expiration time as current time + ttl
        expire_at = time.time() + self.ttl
        with self.lock:
            self.data[key] = (value, expire_at)

    def get(self, key, default=None):
        """
        Get a value from the cache. Returns default if key not found or expired.
        Resets the TTL if the item is successfully accessed; an expired entry
        is removed and its value uploaded to S3.
        """
        expired = None
        with self.lock:
            entry = self.data.get(key)
            if entry is not None:
                value, expire_at = entry
                now = time.time()
                if now <= expire_at:
                    # refresh this data's TTL since it's accessed
                    self.data[key] = (value, now + self.ttl)
                    return value
                # BUG FIX: the expired entry used to stay in self.data, so it
                # was re-uploaded on every subsequent access and never freed.
                del self.data[key]
                expired = (key, value)
        if expired is not None:
            # Upload outside the lock so S3 latency doesn't stall other users.
            self._evict_item(*expired)
        return default

    def __contains__(self, key):
        """
        Check if a key is in the cache and not expired.
        If not expired, reset the TTL; an expired entry is evicted (uploaded).
        """
        expired = None
        with self.lock:
            entry = self.data.get(key)
            if entry is None:
                return False
            value, expire_at = entry
            now = time.time()
            if now <= expire_at:
                # refresh TTL
                self.data[key] = (value, now + self.ttl)
                return True
            del self.data[key]  # BUG FIX: drop the expired entry
            expired = (key, value)
        self._evict_item(*expired)  # upload outside the lock
        return False

    def __getitem__(self, key):
        """
        Retrieve an item i.e. user_cache[key].
        Raises KeyError if not found or expired.
        Resets the TTL if the item is successfully accessed.
        """
        expired = None
        with self.lock:
            entry = self.data.get(key)
            if entry is not None:
                value, expire_at = entry
                now = time.time()
                if now <= expire_at:
                    self.data[key] = (value, now + self.ttl)
                    return value
                del self.data[key]  # BUG FIX: drop the expired entry
                expired = (key, value)
        if expired is not None:
            self._evict_item(*expired)  # upload outside the lock
        raise KeyError(key)

    def reset_cache(self):
        """
        Reset the cache by removing all items, uploading each value to S3.
        """
        with self.lock:
            items = list(self.data.items())
            self.data = {}
        # BUG FIX: the old code handed the whole (value, expire_at) tuple to
        # _evict_item, so save_user() was called on a tuple and every upload
        # failed silently; unpack the value. Uploads also now run outside the
        # lock so slow S3 calls don't block the cache.
        for key, (value, _) in items:
            self._evict_item(key, value)

    def pop(self, key, default=None):
        """
        Remove an item from the cache and return its value (default if absent).
        The value is uploaded to S3 before being handed back.
        """
        with self.lock:
            entry = self.data.pop(key, None)
        if entry is None:
            return default
        value, _ = entry
        # upload before returning, outside the lock
        self._upload_value(key, value)
        return value

    def close(self):
        """
        Stop the background cleanup thread.
        Technically we should call this when done using the cache but I guess we are never done using it lol.
        """
        self.running = False
        self._stop_event.set()  # wake the sleeping cleanup thread immediately
        self.cleanup_thread.join()

    def _cleanup_task(self):
        """
        Background task that periodically checks for expired items and removes them.
        Removal happens under the lock; uploads happen outside it so slow S3
        calls don't block concurrent cache access.
        """
        while self.running:
            now = time.time()
            expired = []
            with self.lock:
                stale_keys = [k for k, (_, exp) in self.data.items() if now > exp]
                for k in stale_keys:
                    value, _ = self.data.pop(k)
                    expired.append((k, value))
            # Evict expired items outside the lock to handle uploading
            for k, value in expired:
                self._evict_item(k, value)
            # Event.wait instead of time.sleep so close() can interrupt it.
            self._stop_event.wait(self.cleanup_interval)

    def _evict_item(self, key, value):
        """
        Called internally when an item expires. upload to S3 before removal.
        """
        self._upload_value(key, value)

    def _upload_value(self, key, value):
        """
        Saves the user's data and uploads the resulting file to S3.
        Failures are logged but never raised into cache operations.
        """
        try:
            filename = value.save_user()
            logger.info(
                f"FILENAME = (unknown)",
                extra={'user_id': key, 'endpoint': 'cache_eviction'}
            )
            upload_file_to_s3(filename)
            logger.info(
                f"User {key} saved to S3 during cache eviction",
                extra={'user_id': key, 'endpoint': 'cache_eviction'}
            )
        except Exception as e:
            logger.error(
                f"Failed to save user {key} to S3 during cache eviction: {e}",
                extra={'user_id': key, 'endpoint': 'cache_eviction'}
            )