simahu committed on
Commit
899124f
·
verified ·
1 Parent(s): 23013b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -0
app.py CHANGED
@@ -4,12 +4,28 @@ from typing import Optional
4
  from transformers import MarianMTModel, MarianTokenizer
5
  import datetime
6
  import logging
 
 
7
 
8
  logger = logging.getLogger("translate")
9
  logger.setLevel(logging.INFO)
10
 
11
  app = FastAPI(title="翻译服务")
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  # 加载翻译模型
14
  MODEL_NAME = "Helsinki-NLP/opus-mt-tc-bible-big-zhx-en"
15
  logger.info(f"{datetime.datetime.now()} Loading model {MODEL_NAME}...")
@@ -26,11 +42,24 @@ class TranslateResponse(BaseModel):
26
 
27
  @app.post("/api/translate", response_model=TranslateResponse)
28
  async def translate(req: TranslateRequest):
 
 
 
 
 
 
 
 
 
 
29
  # tokenizer 会处理编码
30
  batch = tokenizer([req.text], return_tensors="pt", padding=True)
31
  translated = model.generate(**batch)
32
  output = tokenizer.decode(translated[0], skip_special_tokens=True)
33
 
 
 
 
34
  return TranslateResponse(
35
  translated_text=output,
36
  detected_lang=None # 简单翻译版暂不返回检测语言
 
4
  from transformers import MarianMTModel, MarianTokenizer
5
  import datetime
6
  import logging
7
+ import hashlib
8
+ import time
9
 
10
  logger = logging.getLogger("translate")
11
  logger.setLevel(logging.INFO)
12
 
13
  app = FastAPI(title="翻译服务")
14
 
15
+
16
# In-memory translation cache: maps sha256(source text) -> (translated_text, expire_ts).
cache = {}  # {hash: (translated_text, expire_ts)}


def _hash_text(text: str) -> str:
    """Return the hex SHA-256 digest of *text*, encoded as UTF-8."""
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


def _clean_cache():
    """Remove every cache entry whose expiry timestamp has already passed."""
    now = time.time()
    # Collect expired keys first, then delete — avoids mutating while iterating.
    for stale_key in [key for key, (_, expire_ts) in cache.items() if expire_ts < now]:
        del cache[stale_key]
27
+
28
+
29
  # 加载翻译模型
30
  MODEL_NAME = "Helsinki-NLP/opus-mt-tc-bible-big-zhx-en"
31
  logger.info(f"{datetime.datetime.now()} Loading model {MODEL_NAME}...")
 
42
 
43
  @app.post("/api/translate", response_model=TranslateResponse)
44
  async def translate(req: TranslateRequest):
45
+ _clean_cache()
46
+
47
+ h = _hash_text(req.text)
48
+
49
+ # 查缓存
50
+ if h in cache:
51
+ translated_text, expire_ts = cache[h]
52
+ if expire_ts > time.time():
53
+ logger.info(f"Cache hit: {h}")
54
+ return TranslateResponse(translated_text=translated_text)
55
  # tokenizer 会处理编码
56
  batch = tokenizer([req.text], return_tensors="pt", padding=True)
57
  translated = model.generate(**batch)
58
  output = tokenizer.decode(translated[0], skip_special_tokens=True)
59
 
60
+ # 写缓存(保留30分钟)
61
+ cache[h] = (output, time.time() + 30 * 60)
62
+
63
  return TranslateResponse(
64
  translated_text=output,
65
  detected_lang=None # 简单翻译版暂不返回检测语言