mhdzumair commited on
Commit
0424d2b
·
1 Parent(s): dd06f9d

Add hybrid caching system for better performance & fix the limitation with multiple workers

Browse files

Replaced TTLCache with an optimized hybrid cache combining memory and file storage for better performance and scalability. Removed cachetools dependency and updated pyproject.toml to include aiofiles. Adjusted speed test service to use the new caching mechanism.

mediaflow_proxy/speedtest/service.py CHANGED
@@ -1,17 +1,17 @@
1
  import logging
2
  import time
3
- from datetime import datetime
4
  from typing import Dict, Optional, Type
5
 
6
- from cachetools import TTLCache
7
  from httpx import AsyncClient
8
 
 
9
  from mediaflow_proxy.utils.http_utils import Streamer
 
10
  from .models import SpeedTestTask, LocationResult, SpeedTestResult, SpeedTestProvider
11
  from .providers.all_debrid import AllDebridSpeedTest
12
  from .providers.base import BaseSpeedTestProvider
13
  from .providers.real_debrid import RealDebridSpeedTest
14
- from ..configs import settings
15
 
16
  logger = logging.getLogger(__name__)
17
 
@@ -20,9 +20,6 @@ class SpeedTestService:
20
  """Service for managing speed tests across different providers."""
21
 
22
  def __init__(self):
23
- # Cache for speed test results (1 hour TTL)
24
- self._cache: TTLCache[str, SpeedTestTask] = TTLCache(maxsize=100, ttl=3600)
25
-
26
  # Provider mapping
27
  self._providers: Dict[SpeedTestProvider, Type[BaseSpeedTestProvider]] = {
28
  SpeedTestProvider.REAL_DEBRID: RealDebridSpeedTest,
@@ -49,19 +46,22 @@ class SpeedTestService:
49
  # Get initial URLs and user info
50
  urls, user_info = await provider_impl.get_test_urls()
51
 
52
- task = SpeedTestTask(task_id=task_id, provider=provider, started_at=datetime.utcnow(), user_info=user_info)
 
 
53
 
54
- self._cache[task_id] = task
55
  return task
56
 
57
- async def get_test_results(self, task_id: str) -> Optional[SpeedTestTask]:
 
58
  """Get results for a specific task."""
59
- return self._cache.get(task_id)
60
 
61
  async def run_speedtest(self, task_id: str, provider: SpeedTestProvider, api_key: Optional[str] = None):
62
  """Run the speed test with real-time updates."""
63
  try:
64
- task = self._cache.get(task_id)
65
  if not task:
66
  raise ValueError(f"Task {task_id} not found")
67
 
@@ -74,27 +74,28 @@ class SpeedTestService:
74
  for location, url in config.test_urls.items():
75
  try:
76
  task.current_location = location
 
77
  result = await self._test_location(location, url, streamer, config.test_duration, provider_impl)
78
  task.results[location] = result
79
- self._cache[task_id] = task
80
  except Exception as e:
81
  logger.error(f"Error testing {location}: {str(e)}")
82
  task.results[location] = LocationResult(
83
  error=str(e), server_name=location, server_url=config.test_urls[location]
84
  )
85
- self._cache[task_id] = task
86
 
87
  # Mark task as completed
88
- task.completed_at = datetime.utcnow()
89
  task.status = "completed"
90
  task.current_location = None
91
- self._cache[task_id] = task
92
 
93
  except Exception as e:
94
  logger.error(f"Error in speed test task {task_id}: {str(e)}")
95
- if task := self._cache.get(task_id):
96
  task.status = "failed"
97
- self._cache[task_id] = task
98
 
99
  async def _test_location(
100
  self, location: str, url: str, streamer: Streamer, test_duration: int, provider: BaseSpeedTestProvider
 
1
  import logging
2
  import time
3
+ from datetime import datetime, timezone
4
  from typing import Dict, Optional, Type
5
 
 
6
  from httpx import AsyncClient
7
 
8
+ from mediaflow_proxy.configs import settings
9
  from mediaflow_proxy.utils.http_utils import Streamer
10
+ from mediaflow_proxy.utils.cache_utils import get_cached_speedtest, set_cache_speedtest
11
  from .models import SpeedTestTask, LocationResult, SpeedTestResult, SpeedTestProvider
12
  from .providers.all_debrid import AllDebridSpeedTest
13
  from .providers.base import BaseSpeedTestProvider
14
  from .providers.real_debrid import RealDebridSpeedTest
 
15
 
16
  logger = logging.getLogger(__name__)
17
 
 
20
  """Service for managing speed tests across different providers."""
21
 
22
  def __init__(self):
 
 
 
23
  # Provider mapping
24
  self._providers: Dict[SpeedTestProvider, Type[BaseSpeedTestProvider]] = {
25
  SpeedTestProvider.REAL_DEBRID: RealDebridSpeedTest,
 
46
  # Get initial URLs and user info
47
  urls, user_info = await provider_impl.get_test_urls()
48
 
49
+ task = SpeedTestTask(
50
+ task_id=task_id, provider=provider, started_at=datetime.now(tz=timezone.utc), user_info=user_info
51
+ )
52
 
53
+ await set_cache_speedtest(task_id, task)
54
  return task
55
 
56
+ @staticmethod
57
+ async def get_test_results(task_id: str) -> Optional[SpeedTestTask]:
58
  """Get results for a specific task."""
59
+ return await get_cached_speedtest(task_id)
60
 
61
  async def run_speedtest(self, task_id: str, provider: SpeedTestProvider, api_key: Optional[str] = None):
62
  """Run the speed test with real-time updates."""
63
  try:
64
+ task = await get_cached_speedtest(task_id)
65
  if not task:
66
  raise ValueError(f"Task {task_id} not found")
67
 
 
74
  for location, url in config.test_urls.items():
75
  try:
76
  task.current_location = location
77
+ await set_cache_speedtest(task_id, task)
78
  result = await self._test_location(location, url, streamer, config.test_duration, provider_impl)
79
  task.results[location] = result
80
+ await set_cache_speedtest(task_id, task)
81
  except Exception as e:
82
  logger.error(f"Error testing {location}: {str(e)}")
83
  task.results[location] = LocationResult(
84
  error=str(e), server_name=location, server_url=config.test_urls[location]
85
  )
86
+ await set_cache_speedtest(task_id, task)
87
 
88
  # Mark task as completed
89
+ task.completed_at = datetime.now(tz=timezone.utc)
90
  task.status = "completed"
91
  task.current_location = None
92
+ await set_cache_speedtest(task_id, task)
93
 
94
  except Exception as e:
95
  logger.error(f"Error in speed test task {task_id}: {str(e)}")
96
+ if task := await get_cached_speedtest(task_id):
97
  task.status = "failed"
98
+ await set_cache_speedtest(task_id, task)
99
 
100
  async def _test_location(
101
  self, location: str, url: str, streamer: Streamer, test_duration: int, provider: BaseSpeedTestProvider
mediaflow_proxy/utils/cache_utils.py CHANGED
@@ -1,16 +1,346 @@
1
- import datetime
 
2
  import logging
 
 
 
 
 
 
 
 
 
3
 
4
- from cachetools import TTLCache
 
 
5
 
6
- from .http_utils import download_file_with_retry
7
- from .mpd_utils import parse_mpd, parse_mpd_dict
 
8
 
9
  logger = logging.getLogger(__name__)
10
 
11
- # cache dictionary
12
- mpd_cache = TTLCache(maxsize=100, ttl=300) # 5 minutes default TTL
13
- init_segment_cache = TTLCache(maxsize=100, ttl=3600) # 1 hour default TTL
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
 
16
  async def get_cached_mpd(
@@ -20,54 +350,49 @@ async def get_cached_mpd(
20
  parse_segment_profile_id: str | None = None,
21
  verify_ssl: bool = True,
22
  use_request_proxy: bool = True,
23
- ) -> dict:
24
- """
25
- Retrieves and caches the MPD manifest, parsing it if not already cached.
26
-
27
- Args:
28
- mpd_url (str): The URL of the MPD manifest.
29
- headers (dict): The headers to include in the request.
30
- parse_drm (bool): Whether to parse DRM information.
31
- parse_segment_profile_id (str, optional): The profile ID to parse segments for. Defaults to None.
32
- verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.
33
- use_request_proxy (bool, optional): Whether to use the proxy configuration from the user's MediaFlow config. Defaults to True.
34
-
35
- Returns:
36
- dict: The parsed MPD manifest data.
37
- """
38
- current_time = datetime.datetime.now(datetime.UTC)
39
- if mpd_url in mpd_cache and mpd_cache[mpd_url]["expires"] > current_time:
40
- logger.info(f"Using cached MPD for {mpd_url}")
41
- return parse_mpd_dict(mpd_cache[mpd_url]["mpd"], mpd_url, parse_drm, parse_segment_profile_id)
42
-
43
- mpd_dict = parse_mpd(
44
- await download_file_with_retry(mpd_url, headers, verify_ssl=verify_ssl, use_request_proxy=use_request_proxy)
45
- )
46
- parsed_mpd_dict = parse_mpd_dict(mpd_dict, mpd_url, parse_drm, parse_segment_profile_id)
47
- current_time = datetime.datetime.now(datetime.UTC)
48
- expiration_time = current_time + datetime.timedelta(seconds=parsed_mpd_dict.get("minimumUpdatePeriod", 300))
49
- mpd_cache[mpd_url] = {"mpd": mpd_dict, "expires": expiration_time}
50
- return parsed_mpd_dict
51
-
52
 
53
- async def get_cached_init_segment(
54
- init_url: str, headers: dict, verify_ssl: bool = True, use_request_proxy: bool = True
55
- ) -> bytes:
56
- """
57
- Retrieves and caches the initialization segment.
58
-
59
- Args:
60
- init_url (str): The URL of the initialization segment.
61
- headers (dict): The headers to include in the request.
62
- verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.
63
- use_request_proxy (bool, optional): Whether to use the proxy configuration from the user's MediaFlow config. Defaults to True.
64
-
65
- Returns:
66
- bytes: The initialization segment content.
67
- """
68
- if init_url not in init_segment_cache:
69
- init_content = await download_file_with_retry(
70
- init_url, headers, verify_ssl=verify_ssl, use_request_proxy=use_request_proxy
71
  )
72
- init_segment_cache[init_url] = init_content
73
- return init_segment_cache[init_url]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio
import hashlib
import json
import logging
import os
import tempfile
import threading
import time
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union, Any

import aiofiles
import aiofiles.os
from pydantic import ValidationError

from mediaflow_proxy.speedtest.models import SpeedTestTask
from mediaflow_proxy.utils.http_utils import download_file_with_retry
from mediaflow_proxy.utils.mpd_utils import parse_mpd, parse_mpd_dict
21
 
22
  logger = logging.getLogger(__name__)
23
 
24
+
25
@dataclass
class CacheEntry:
    """Represents a cache entry with metadata."""

    # Raw cached payload (bytes-like; producers pass bytes/bytearray/memoryview).
    data: bytes
    # Absolute expiry time as a time.time() epoch timestamp.
    expires_at: float
    # How many times this entry has been served from cache.
    access_count: int = 0
    # Epoch timestamp of the most recent read of this entry.
    last_access: float = 0.0
    # Payload size in bytes, used by the memory layer for its byte budget.
    size: int = 0
34
+
35
+
36
class CacheStats:
    """Tracks cache performance metrics.

    Counters are guarded by a lock so concurrent threads can record
    hits/misses safely.
    """

    def __init__(self):
        self.hits = 0
        self.misses = 0
        self.memory_hits = 0
        self.disk_hits = 0
        self._lock = threading.Lock()

    def record_hit(self, from_memory: bool):
        """Count one hit, attributed to the memory or the disk layer."""
        counter_name = "memory_hits" if from_memory else "disk_hits"
        with self._lock:
            self.hits += 1
            setattr(self, counter_name, getattr(self, counter_name) + 1)

    def record_miss(self):
        """Count one lookup that found nothing in either layer."""
        with self._lock:
            self.misses += 1

    @property
    def hit_rate(self) -> float:
        """Fraction of lookups served from cache (0.0 when no lookups yet)."""
        total = self.hits + self.misses
        if not total:
            return 0.0
        return self.hits / total

    def __str__(self) -> str:
        return (
            "Cache Stats: Hits={} (Memory: {}, "
            "Disk: {}), Misses={}, "
            "Hit Rate={:.2%}".format(self.hits, self.memory_hits, self.disk_hits, self.misses, self.hit_rate)
        )
+ )
69
+
70
+
71
class AsyncLRUMemoryCache:
    """Thread-safe LRU memory cache with async support.

    Capacity is measured in payload bytes (entry.size), not entry count.
    """

    def __init__(self, maxsize: int):
        """
        Args:
            maxsize: Maximum total payload size, in bytes, kept in memory.
        """
        self.maxsize = maxsize
        self._cache: "OrderedDict[str, CacheEntry]" = OrderedDict()
        self._lock = threading.Lock()
        self._current_size = 0

    def get(self, key: str) -> "Optional[CacheEntry]":
        """Return the live entry for ``key`` (refreshing LRU order), or None.

        Expired entries are dropped on access.
        """
        with self._lock:
            entry = self._cache.pop(key, None)
            if entry is None:
                return None
            if time.time() >= entry.expires_at:
                # Expired: entry is already popped, just release its size.
                # (Previously this branch also did `del self._cache[key]`
                # after the pop(), which raised KeyError on every expired
                # lookup.)
                self._current_size -= entry.size
                return None
            entry.access_count += 1
            entry.last_access = time.time()
            # Re-insert so the entry becomes the most recently used.
            self._cache[key] = entry
            return entry

    def set(self, key: str, entry: "CacheEntry") -> None:
        """Insert or replace ``key``, evicting LRU entries to fit the budget."""
        with self._lock:
            # Remove any existing entry first so the eviction loop cannot
            # evict the key we are about to overwrite and double-subtract
            # its size from the running total.
            old_entry = self._cache.pop(key, None)
            if old_entry is not None:
                self._current_size -= old_entry.size

            # Evict least-recently-used entries until the new one fits.
            while self._current_size + entry.size > self.maxsize and self._cache:
                _, removed_entry = self._cache.popitem(last=False)
                self._current_size -= removed_entry.size

            self._cache[key] = entry
            self._current_size += entry.size

    def remove(self, key: str) -> None:
        """Delete ``key`` if present, keeping the byte total consistent."""
        with self._lock:
            entry = self._cache.pop(key, None)
            if entry is not None:
                self._current_size -= entry.size
114
+
115
+
116
class OptimizedHybridCache:
    """High-performance hybrid cache combining memory and file storage.

    Lookups try the per-process LRU memory layer first and fall back to a
    sharded file cache under the system temp directory; because the file
    layer lives on disk, multiple worker processes can share entries.
    File format: 8-byte big-endian metadata length, JSON metadata, payload.
    """

    def __init__(
        self,
        cache_dir_name: str,
        ttl: int,
        max_memory_size: int = 100 * 1024 * 1024,  # 100MB default
        file_shards: int = 256,  # Number of subdirectories for sharding
        executor_workers: int = 4,
    ):
        """
        Args:
            cache_dir_name: Directory name for the file cache (under tempdir).
            ttl: Default time-to-live for entries, in seconds.
            max_memory_size: Byte budget for the in-memory LRU layer.
            file_shards: Number of subdirectories used to spread cache files.
            executor_workers: Worker count for the blocking-work thread pool.
        """
        self.cache_dir = Path(tempfile.gettempdir()) / cache_dir_name
        self.ttl = ttl
        self.file_shards = file_shards
        self.memory_cache = AsyncLRUMemoryCache(maxsize=max_memory_size)
        self.stats = CacheStats()
        self._executor = ThreadPoolExecutor(max_workers=executor_workers)
        self._lock = asyncio.Lock()

        # Initialize cache directories
        self._init_cache_dirs()

    def _init_cache_dirs(self):
        """Initialize sharded cache directories."""
        for i in range(self.file_shards):
            shard_dir = self.cache_dir / f"shard_{i:03d}"
            os.makedirs(shard_dir, exist_ok=True)

    @staticmethod
    def _digest(key: str) -> str:
        """Return a stable hex digest for a cache key.

        The builtin hash() is salted per process (PYTHONHASHSEED), so it
        maps the same key to different file names in different worker
        processes and silently breaks cross-worker cache sharing; md5 is
        deterministic everywhere. (md5 is used for naming, not security.)
        """
        return hashlib.md5(key.encode("utf-8")).hexdigest()

    def _get_shard_path(self, key: str) -> Path:
        """Get the appropriate shard directory for a key."""
        shard_num = int(self._digest(key), 16) % self.file_shards
        return self.cache_dir / f"shard_{shard_num:03d}"

    def _get_file_path(self, key: str) -> Path:
        """Get the file path for a cache key."""
        return self._get_shard_path(key) / self._digest(key)

    async def get(self, key: str, default: Any = None) -> Optional[bytes]:
        """
        Get value from cache, trying memory first then file.

        Args:
            key: Cache key
            default: Default value if key not found

        Returns:
            Cached value or default if not found
        """
        # Try memory cache first
        entry = self.memory_cache.get(key)
        if entry is not None:
            self.stats.record_hit(from_memory=True)
            return entry.data

        # Try file cache
        try:
            file_path = self._get_file_path(key)
            async with aiofiles.open(file_path, "rb") as f:
                metadata_size = await f.read(8)
                metadata_length = int.from_bytes(metadata_size, "big")
                metadata_bytes = await f.read(metadata_length)
                metadata = json.loads(metadata_bytes.decode())

                # Check expiration
                if metadata["expires_at"] < time.time():
                    await self.delete(key)
                    self.stats.record_miss()
                    return default

                # Read payload (remainder of the file after the header)
                data = await f.read()

                # Promote the entry into the memory layer for faster re-reads
                entry = CacheEntry(
                    data=data,
                    expires_at=metadata["expires_at"],
                    access_count=metadata["access_count"] + 1,
                    last_access=time.time(),
                    size=len(data),
                )
                self.memory_cache.set(key, entry)

                self.stats.record_hit(from_memory=False)
                return data

        except FileNotFoundError:
            self.stats.record_miss()
            return default
        except Exception as e:
            logger.error(f"Error reading from cache: {e}")
            self.stats.record_miss()
            return default

    async def set(self, key: str, data: Union[bytes, bytearray, memoryview], ttl: Optional[int] = None) -> bool:
        """
        Set value in both memory and file cache.

        Args:
            key: Cache key
            data: Data to cache
            ttl: Optional TTL override

        Returns:
            bool: Success status
        """
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise ValueError("Data must be bytes, bytearray, or memoryview")

        # Normalize so both layers hand back plain bytes regardless of input
        data = bytes(data)
        expires_at = time.time() + (ttl or self.ttl)

        # Create cache entry
        entry = CacheEntry(data=data, expires_at=expires_at, access_count=0, last_access=time.time(), size=len(data))

        # Update memory cache
        self.memory_cache.set(key, entry)

        # Update file cache: write to a temp file, then rename so concurrent
        # readers never observe a partially written entry. Paths are computed
        # outside the try so the cleanup path cannot hit an unbound name.
        file_path = self._get_file_path(key)
        temp_path = file_path.with_suffix(".tmp")
        try:
            metadata = {"expires_at": expires_at, "access_count": 0, "last_access": time.time()}
            metadata_bytes = json.dumps(metadata).encode()
            metadata_size = len(metadata_bytes).to_bytes(8, "big")

            async with aiofiles.open(temp_path, "wb") as f:
                await f.write(metadata_size)
                await f.write(metadata_bytes)
                await f.write(data)

            await aiofiles.os.rename(temp_path, file_path)
            return True

        except Exception as e:
            logger.error(f"Error writing to cache: {e}")
            try:
                await aiofiles.os.remove(temp_path)
            except OSError:
                # Best-effort cleanup; the temp file may never have been created
                pass
            return False

    async def delete(self, key: str) -> bool:
        """Delete item from both caches."""
        self.memory_cache.remove(key)

        try:
            file_path = self._get_file_path(key)
            await aiofiles.os.remove(file_path)
            return True
        except FileNotFoundError:
            return True
        except Exception as e:
            logger.error(f"Error deleting from cache: {e}")
            return False

    async def cleanup_expired(self):
        """Clean up expired cache entries."""
        current_time = time.time()

        async def check_and_clean_file(file_path: Path):
            # Read just the metadata header; delete the file when expired.
            try:
                async with aiofiles.open(file_path, "rb") as f:
                    metadata_size = await f.read(8)
                    metadata_length = int.from_bytes(metadata_size, "big")
                    metadata_bytes = await f.read(metadata_length)
                    metadata = json.loads(metadata_bytes.decode())

                if metadata["expires_at"] < current_time:
                    await aiofiles.os.remove(file_path)
            except Exception as e:
                logger.error(f"Error cleaning up file {file_path}: {e}")

        # Clean up each shard
        for i in range(self.file_shards):
            shard_dir = self.cache_dir / f"shard_{i:03d}"
            try:
                # aiofiles.os.scandir is an awaitable returning a plain
                # iterator; the original `async for` over the coroutine
                # raised TypeError and broke cleanup entirely.
                entries = await aiofiles.os.scandir(shard_dir)
                for dir_entry in entries:
                    if dir_entry.is_file() and not dir_entry.name.endswith(".tmp"):
                        await check_and_clean_file(Path(dir_entry.path))
            except Exception as e:
                logger.error(f"Error scanning shard directory {shard_dir}: {e}")
298
+
299
+
300
# Create cache instances.
# Module-level singletons shared by all importers of this module; each gets
# its own temp-dir subtree, TTL, and memory budget tuned to its payload size.
INIT_SEGMENT_CACHE = OptimizedHybridCache(
    cache_dir_name="init_segment_cache",
    ttl=3600,  # 1 hour
    max_memory_size=500 * 1024 * 1024,  # 500MB for init segments
    file_shards=512,  # More shards for better distribution
)

MPD_CACHE = OptimizedHybridCache(
    cache_dir_name="mpd_cache",
    ttl=300,  # 5 minutes
    max_memory_size=100 * 1024 * 1024,  # 100MB for MPD files
    file_shards=128,
)

SPEEDTEST_CACHE = OptimizedHybridCache(
    cache_dir_name="speedtest_cache",
    ttl=3600,  # 1 hour
    max_memory_size=50 * 1024 * 1024,  # 50MB for speed test results
    file_shards=64,
)
321
+
322
+
323
+ # Specific cache implementations
324
async def get_cached_init_segment(
    init_url: str, headers: dict, verify_ssl: bool = True, use_request_proxy: bool = True
) -> Optional[bytes]:
    """Return the initialization segment for ``init_url``.

    Serves from INIT_SEGMENT_CACHE when possible; otherwise downloads the
    segment, caches any non-empty result, and returns it. Returns None when
    the download fails.
    """
    cached_segment = await INIT_SEGMENT_CACHE.get(init_url)
    if cached_segment is not None:
        return cached_segment

    # Cache miss: fetch from origin
    try:
        fetched = await download_file_with_retry(
            init_url, headers, verify_ssl=verify_ssl, use_request_proxy=use_request_proxy
        )
    except Exception as e:
        logger.error(f"Error downloading init segment: {e}")
        return None

    if fetched:
        await INIT_SEGMENT_CACHE.set(init_url, fetched)
    return fetched
344
 
345
 
346
  async def get_cached_mpd(
 
350
  parse_segment_profile_id: str | None = None,
351
  verify_ssl: bool = True,
352
  use_request_proxy: bool = True,
353
+ ) -> Optional[dict]:
354
+ """Get MPD from cache or download and parse it."""
355
+ # Try cache first
356
+ cached_data = await MPD_CACHE.get(mpd_url)
357
+ if cached_data is not None:
358
+ try:
359
+ mpd_dict = json.loads(cached_data)
360
+ return parse_mpd_dict(mpd_dict, mpd_url, parse_drm, parse_segment_profile_id)
361
+ except json.JSONDecodeError:
362
+ await MPD_CACHE.delete(mpd_url)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
 
364
+ # Download and parse if not cached
365
+ try:
366
+ mpd_content = await download_file_with_retry(
367
+ mpd_url, headers, verify_ssl=verify_ssl, use_request_proxy=use_request_proxy
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
  )
369
+ mpd_dict = parse_mpd(mpd_content)
370
+ parsed_dict = parse_mpd_dict(mpd_dict, mpd_url, parse_drm, parse_segment_profile_id)
371
+
372
+ # Cache the original MPD dict
373
+ await MPD_CACHE.set(mpd_url, json.dumps(mpd_dict).encode())
374
+ return parsed_dict
375
+ except Exception as e:
376
+ logger.error(f"Error processing MPD: {e}")
377
+ return None
378
+
379
+
380
async def get_cached_speedtest(task_id: str) -> Optional[SpeedTestTask]:
    """Get speed test results from cache.

    Returns None on a cache miss; a cached payload that fails model
    validation is logged, evicted, and also reported as a miss.
    """
    raw = await SPEEDTEST_CACHE.get(task_id)
    if raw is None:
        return None
    try:
        return SpeedTestTask.model_validate_json(raw.decode())
    except ValidationError as e:
        logger.error(f"Error parsing cached speed test data: {e}")
        await SPEEDTEST_CACHE.delete(task_id)
        return None
390
+
391
+
392
async def set_cache_speedtest(task_id: str, task: SpeedTestTask) -> bool:
    """Cache speed test results.

    Serializes the task to JSON bytes and stores it; returns False (after
    logging) if serialization or the cache write raises.
    """
    try:
        payload = task.model_dump_json().encode()
        return await SPEEDTEST_CACHE.set(task_id, payload)
    except Exception as e:
        logger.error(f"Error caching speed test data: {e}")
        return False
poetry.lock CHANGED
@@ -1,5 +1,16 @@
1
  # This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
2
 
 
 
 
 
 
 
 
 
 
 
 
3
  [[package]]
4
  name = "annotated-types"
5
  version = "0.7.0"
@@ -79,17 +90,6 @@ d = ["aiohttp (>=3.10)"]
79
  jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
80
  uvloop = ["uvloop (>=0.15.2)"]
81
 
82
- [[package]]
83
- name = "cachetools"
84
- version = "5.5.0"
85
- description = "Extensible memoizing collections and decorators"
86
- optional = false
87
- python-versions = ">=3.7"
88
- files = [
89
- {file = "cachetools-5.5.0-py3-none-any.whl", hash = "sha256:02134e8439cdc2ffb62023ce1debca2944c3f289d66bb17ead3ab3dede74b292"},
90
- {file = "cachetools-5.5.0.tar.gz", hash = "sha256:2cc24fb4cbe39633fb7badd9db9ca6295d766d9c2995f245725a46715d050f2a"},
91
- ]
92
-
93
  [[package]]
94
  name = "certifi"
95
  version = "2024.8.30"
@@ -631,4 +631,4 @@ files = [
631
  [metadata]
632
  lock-version = "2.0"
633
  python-versions = ">=3.10"
634
- content-hash = "b9b0a0539cd08dd58b46d7fc41e940666ae658a68384c19a2fb692001aa71120"
 
1
  # This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
2
 
3
+ [[package]]
4
+ name = "aiofiles"
5
+ version = "24.1.0"
6
+ description = "File support for asyncio."
7
+ optional = false
8
+ python-versions = ">=3.8"
9
+ files = [
10
+ {file = "aiofiles-24.1.0-py3-none-any.whl", hash = "sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5"},
11
+ {file = "aiofiles-24.1.0.tar.gz", hash = "sha256:22a075c9e5a3810f0c2e48f3008c94d68c65d763b9b03857924c99e57355166c"},
12
+ ]
13
+
14
  [[package]]
15
  name = "annotated-types"
16
  version = "0.7.0"
 
90
  jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
91
  uvloop = ["uvloop (>=0.15.2)"]
92
 
 
 
 
 
 
 
 
 
 
 
 
93
  [[package]]
94
  name = "certifi"
95
  version = "2024.8.30"
 
631
  [metadata]
632
  lock-version = "2.0"
633
  python-versions = ">=3.10"
634
+ content-hash = "31ae30f007ef7dacc5a13f41d1f88f3d2112e10e72846df59cb0956593bb33b9"
pyproject.toml CHANGED
@@ -27,12 +27,12 @@ fastapi = "0.115.4"
27
  httpx = {extras = ["socks"], version = "^0.27.2"}
28
  tenacity = "^9.0.0"
29
  xmltodict = "^0.14.2"
30
- cachetools = "^5.4.0"
31
  pydantic-settings = "^2.6.1"
32
  gunicorn = "^23.0.0"
33
  pycryptodome = "^3.20.0"
34
  uvicorn = "^0.32.0"
35
  tqdm = "^4.67.0"
 
36
 
37
 
38
  [tool.poetry.group.dev.dependencies]
 
27
  httpx = {extras = ["socks"], version = "^0.27.2"}
28
  tenacity = "^9.0.0"
29
  xmltodict = "^0.14.2"
 
30
  pydantic-settings = "^2.6.1"
31
  gunicorn = "^23.0.0"
32
  pycryptodome = "^3.20.0"
33
  uvicorn = "^0.32.0"
34
  tqdm = "^4.67.0"
35
+ aiofiles = "^24.1.0"
36
 
37
 
38
  [tool.poetry.group.dev.dependencies]