superxu520 commited on
Commit
d4a9626
·
1 Parent(s): 4647ed2

fix: 修复速率限制器死锁风险、临时文件清理兼容性和图片目录初始化问题

Browse files
Files changed (2) hide show
  1. app/server/middleware.py +36 -16
  2. app/server/rate_limiter.py +18 -36
app/server/middleware.py CHANGED
@@ -16,13 +16,23 @@ from ..utils import g_config
16
 
17
  # Persistent directory for storing generated images
18
  # Support environment variable override for Docker/HF deployments
19
- IMAGE_STORE_DIR = Path(os.getenv("GEMINI_IMAGE_STORE_PATH", tempfile.gettempdir())) / "ai_generated_images"
20
- IMAGE_STORE_DIR.mkdir(parents=True, exist_ok=True)
21
 
22
 
23
  def get_image_store_dir() -> Path:
24
- """Returns a persistent directory for storing images."""
25
- return IMAGE_STORE_DIR
 
 
 
 
 
 
 
 
 
 
26
 
27
 
28
  def get_image_token(filename: str) -> str:
@@ -47,24 +57,30 @@ def verify_image_token(filename: str, token: str | None) -> bool:
47
 
48
 
49
  def cleanup_expired_images(retention_days: int) -> int:
50
- """Delete images in IMAGE_STORE_DIR older than retention_days."""
51
  if retention_days <= 0:
52
  return 0
53
 
 
54
  now = time.time()
55
  retention_seconds = retention_days * 24 * 60 * 60
56
  cutoff = now - retention_seconds
57
 
58
  count = 0
59
- for file_path in IMAGE_STORE_DIR.iterdir():
60
- if not file_path.is_file():
61
- continue
62
- try:
63
- if file_path.stat().st_mtime < cutoff:
64
- file_path.unlink()
65
- count += 1
66
- except Exception as e:
67
- logger.warning(f"Failed to delete expired image {file_path}: {e}")
 
 
 
 
 
68
 
69
  if count > 0:
70
  logger.info(f"Cleaned up {count} expired images.")
@@ -93,8 +109,12 @@ async def get_temp_dir():
93
  try:
94
  yield Path(temp_dir.name)
95
  finally:
96
- # Run cleanup in thread pool to avoid blocking
97
- await asyncio.get_event_loop().run_in_executor(None, temp_dir.cleanup)
 
 
 
 
98
 
99
 
100
  def verify_api_key(
 
16
 
17
  # Persistent directory for storing generated images
18
  # Support environment variable override for Docker/HF deployments
19
+ # Directory will be created on first access to avoid import-time failures
20
+ _IMAGE_STORE_DIR: Path | None = None
21
 
22
 
23
  def get_image_store_dir() -> Path:
24
+ """Returns a persistent directory for storing images. Creates it if it doesn't exist."""
25
+ global _IMAGE_STORE_DIR
26
+ if _IMAGE_STORE_DIR is None:
27
+ _IMAGE_STORE_DIR = Path(os.getenv("GEMINI_IMAGE_STORE_PATH", tempfile.gettempdir())) / "ai_generated_images"
28
+ try:
29
+ _IMAGE_STORE_DIR.mkdir(parents=True, exist_ok=True)
30
+ except OSError as e:
31
+ logger.error(f"Failed to create image store directory at {_IMAGE_STORE_DIR}: {e}")
32
+ logger.warning("Falling back to system temp directory")
33
+ _IMAGE_STORE_DIR = Path(tempfile.gettempdir()) / "ai_generated_images"
34
+ _IMAGE_STORE_DIR.mkdir(parents=True, exist_ok=True)
35
+ return _IMAGE_STORE_DIR
36
 
37
 
38
  def get_image_token(filename: str) -> str:
 
57
 
58
 
59
def cleanup_expired_images(retention_days: int) -> int:
    """Remove stored images whose mtime is older than ``retention_days`` days.

    A non-positive retention disables cleanup entirely. Per-file failures
    are logged and skipped; a missing store directory is treated as
    "nothing to do" rather than an error.

    Args:
        retention_days: age threshold in days; ``<= 0`` means never delete.

    Returns:
        Number of files actually deleted.
    """
    if retention_days <= 0:
        return 0

    image_store_dir = get_image_store_dir()
    # Anything last modified before this timestamp is considered expired.
    cutoff = time.time() - retention_days * 86400
    count = 0

    try:
        for file_path in image_store_dir.iterdir():
            if file_path.is_file():
                try:
                    if file_path.stat().st_mtime < cutoff:
                        file_path.unlink()
                        count += 1
                except Exception as e:
                    logger.warning(f"Failed to delete expired image {file_path}: {e}")
    except FileNotFoundError:
        logger.debug(f"Image store directory does not exist yet: {image_store_dir}")
    except Exception as e:
        logger.warning(f"Failed to cleanup expired images: {e}")

    if count > 0:
        logger.info(f"Cleaned up {count} expired images.")
    return count
 
109
  try:
110
  yield Path(temp_dir.name)
111
  finally:
112
+ # Run cleanup in thread pool to avoid blocking (Python 3.9+)
113
+ try:
114
+ await asyncio.to_thread(temp_dir.cleanup)
115
+ except AttributeError:
116
+ # Fallback for Python < 3.9
117
+ await asyncio.get_running_loop().run_in_executor(None, temp_dir.cleanup)
118
 
119
 
120
  def verify_api_key(
app/server/rate_limiter.py CHANGED
@@ -5,7 +5,6 @@ Protects Gemini API from being overwhelmed by too many concurrent requests.
5
 
6
  import asyncio
7
  import time
8
- from collections import defaultdict
9
  from typing import Callable
10
 
11
  from fastapi import HTTPException, Request, status
@@ -14,7 +13,7 @@ from loguru import logger
14
 
15
  class RateLimiter:
16
  """
17
- Token bucket rate limiter for concurrent requests.
18
 
19
  Limits the number of simultaneous requests being processed.
20
  When limit is exceeded, new requests are queued or rejected.
@@ -30,51 +29,34 @@ class RateLimiter:
30
  """
31
  self.max_concurrent = max_concurrent
32
  self.queue_timeout = queue_timeout
33
- self._current_count = 0
 
34
  self._lock = asyncio.Lock()
35
- self._waiters = 0
36
 
37
  async def acquire(self) -> None:
38
  """
39
  Acquire permission to process a request.
40
  Blocks until a slot is available or timeout.
41
  """
42
- start_time = time.monotonic()
43
-
44
- async with self._lock:
45
- self._waiters += 1
46
- try:
47
- while self._current_count >= self.max_concurrent:
48
- # Check timeout
49
- elapsed = time.monotonic() - start_time
50
- if elapsed >= self.queue_timeout:
51
- raise HTTPException(
52
- status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
53
- detail="Server is busy. Please try again later.",
54
- )
55
-
56
- # Wait for a slot to become available
57
- self._lock.release()
58
- try:
59
- await asyncio.sleep(0.1) # Small delay to avoid busy waiting
60
- finally:
61
- await self._lock.acquire()
62
-
63
- # Re-check timeout after sleep
64
- if time.monotonic() - start_time >= self.queue_timeout:
65
- raise HTTPException(
66
- status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
67
- detail="Server is busy. Please try again later.",
68
- )
69
-
70
- self._current_count += 1
71
- finally:
72
- self._waiters -= 1
73
 
74
  async def release(self) -> None:
75
  """Release a slot after request processing completes."""
76
  async with self._lock:
77
- self._current_count -= 1
 
 
78
 
79
 
80
  # Global rate limiter instance
 
5
 
6
  import asyncio
7
  import time
 
8
  from typing import Callable
9
 
10
  from fastapi import HTTPException, Request, status
 
13
 
14
  class RateLimiter:
15
  """
16
+ Semaphore-based rate limiter for concurrent requests.
17
 
18
  Limits the number of simultaneous requests being processed.
19
  When limit is exceeded, new requests are queued or rejected.
 
29
  """
30
  self.max_concurrent = max_concurrent
31
  self.queue_timeout = queue_timeout
32
+ self._semaphore = asyncio.Semaphore(max_concurrent)
33
+ self._acquired_count = 0
34
  self._lock = asyncio.Lock()
 
35
 
36
  async def acquire(self) -> None:
37
  """
38
  Acquire permission to process a request.
39
  Blocks until a slot is available or timeout.
40
  """
41
+ try:
42
+ # Use asyncio.wait_for to implement timeout
43
+ await asyncio.wait_for(self._semaphore.acquire(), timeout=self.queue_timeout)
44
+ async with self._lock:
45
+ self._acquired_count += 1
46
+ logger.debug(f"Rate limiter: acquired slot ({self._acquired_count}/{self.max_concurrent})")
47
+ except asyncio.TimeoutError:
48
+ logger.warning(f"Rate limiter: request queued for {self.queue_timeout}s, rejecting")
49
+ raise HTTPException(
50
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
51
+ detail="Server is busy. Please try again later.",
52
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  async def release(self) -> None:
55
  """Release a slot after request processing completes."""
56
  async with self._lock:
57
+ self._acquired_count -= 1
58
+ logger.debug(f"Rate limiter: released slot ({self._acquired_count}/{self.max_concurrent})")
59
+ self._semaphore.release()
60
 
61
 
62
  # Global rate limiter instance