| import threading |
| import time |
|
|
|
|
class TokenBucket:
    """Thread-safe token-bucket rate limiter refilled by a background thread.

    A background thread adds one token every ``1 / rate`` seconds, where
    ``rate = tpm / 60`` tokens per second. At most ``tpm`` tokens are stored.
    The bucket starts empty.
    """

    def __init__(self, tpm, timeout=None):
        """Create the bucket and start its refill thread.

        Args:
            tpm: Tokens per minute. ``int(tpm)`` is also the bucket capacity.
            timeout: Maximum number of seconds ``get_token`` waits for a
                token. ``None`` waits until a token arrives or the bucket
                is closed.

        Raises:
            ValueError: If ``int(tpm)`` is not positive. The refill interval
                ``1 / rate`` would then be undefined.
        """
        self.capacity = int(tpm)
        if self.capacity <= 0:
            raise ValueError(f"tpm must be a positive integer, got {tpm!r}")
        self.tokens = 0
        self.rate = self.capacity / 60
        self.timeout = timeout
        self.cond = threading.Condition()
        self.is_running = True
        # Set by close() so the refill loop stops without finishing its pause.
        self._stop_event = threading.Event()
        # Daemon thread: a bucket that is never closed must not keep the
        # interpreter alive forever.
        self._thread = threading.Thread(target=self._generate_tokens, daemon=True)
        self._thread.start()

    def _generate_tokens(self):
        """Add one token per refill interval until the bucket is closed."""
        while self.is_running:
            with self.cond:
                if self.tokens < self.capacity:
                    self.tokens += 1
                    self.cond.notify()
            # Event.wait instead of time.sleep: close() interrupts the pause
            # immediately instead of after up to 60 s (tpm=1).
            if self._stop_event.wait(1 / self.rate):
                break

    def get_token(self):
        """Take one token, waiting up to ``self.timeout`` seconds for one.

        Returns:
            True if a token was taken. False if the timeout expired, or if
            the bucket was closed and no tokens remain.
        """
        with self.cond:
            # wait_for enforces a single overall deadline. Spurious wakeups,
            # or a token taken by another waiter, do not restart the timeout.
            self.cond.wait_for(
                lambda: self.tokens > 0 or not self.is_running, self.timeout
            )
            if self.tokens <= 0:
                return False
            self.tokens -= 1
            return True

    def close(self):
        """Stop the refill thread and wake every caller blocked in get_token.

        Tokens already in the bucket can still be taken. Once they are used
        up, get_token returns False instead of blocking forever. Calling
        close more than once is harmless.
        """
        with self.cond:
            self.is_running = False
            self.cond.notify_all()
        self._stop_event.set()
|
|
|
|
if __name__ == "__main__":
    # Demo: make three token requests against a 20-tokens-per-minute bucket
    # and print a message for each request that succeeds.
    token_bucket = TokenBucket(20, None)
    try:
        for i in range(3):
            if token_bucket.get_token():
                print(f"第{i+1}次请求成功")
    finally:
        # Always stop the refill thread, even if interrupted (e.g. Ctrl+C
        # while blocked in get_token), so the process can exit.
        token_bucket.close()
|
|