danicor committed on
Commit
eda5854
·
verified ·
1 Parent(s): 19b3421

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -0
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# main.py
"""FastAPI translation microservice: Hugging Face seq2seq model + Redis cache."""
from fastapi import FastAPI, HTTPException, Depends
from pydantic import BaseModel
from typing import Optional, Dict
import os
import redis
import hashlib
import json
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import asyncio
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager

# Configuration — overridable via environment variables so deployments don't
# require code edits; defaults match the previous hard-coded values.
CACHE_TTL = int(os.environ.get("CACHE_TTL", "3600"))  # seconds; 1 hour default
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Global model and tokenizer, populated by the lifespan handler at startup.
model = None
tokenizer = None
23
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan handler: load the model/tokenizer at startup, free them on shutdown."""
    global model, tokenizer
    checkpoint = "Helsinki-NLP/opus-mt-mul-en"  # example multilingual model
    print(f"Loading model on {DEVICE}...")
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(DEVICE)
    print("Model loaded successfully")
    yield
    # Drop references on shutdown so the weights can be garbage-collected.
    if model:
        del model
    if tokenizer:
        del tokenizer
38
+
39
# Application wiring: HTTP app, one shared Redis connection (decode_responses
# so cached values come back as str, not bytes), and a small thread pool so
# blocking model inference never stalls the asyncio event loop.
app = FastAPI(lifespan=lifespan)
redis_client = redis.Redis.from_url(REDIS_URL, decode_responses=True)
executor = ThreadPoolExecutor(max_workers=4)
42
+
43
class TranslationRequest(BaseModel):
    """Request payload for POST /translate."""
    text: str          # text to translate
    source_lang: str   # source language code (presumably ISO codes like "en" — confirm with callers)
    target_lang: str   # target language code
47
+
48
class TranslationResponse(BaseModel):
    """Response payload for POST /translate."""
    translated_text: str  # model output, or the cached translation
    from_cache: bool      # True when the result was served from Redis
    character_count: int  # length of the *input* text, not the translation
52
+
53
def generate_cache_key(text: str, source_lang: str, target_lang: str) -> str:
    """Return a deterministic cache key for a (text, source, target) triple.

    The fields are JSON-encoded as a list before hashing so that an
    underscore inside `text` can never merge with the field separator:
    the previous f"{text}_{src}_{tgt}" scheme made e.g. ("a_b", "c", t)
    and ("a", "b_c", t) hash to the same key, returning wrong cache hits.
    (Changing the scheme only invalidates old cache entries once.)
    """
    payload = json.dumps([text, source_lang, target_lang], ensure_ascii=False)
    return hashlib.md5(payload.encode("utf-8")).hexdigest()
57
+
58
def translate_text(text: str, source_lang: str, target_lang: str) -> str:
    """Run the Hugging Face model synchronously and return the translated string."""
    # Prefix non-English source text with the model's target-language token;
    # English input is passed through unmodified.
    prompt = text if source_lang == "en" else f">>{target_lang}<< {text}"

    encoded = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(DEVICE)

    # Inference only — no gradients needed.
    with torch.no_grad():
        generated = model.generate(**encoded, max_length=512)

    return tokenizer.decode(generated[0], skip_special_tokens=True)
71
+
72
@app.post("/translate", response_model=TranslationResponse)
async def translate(request: TranslationRequest):
    """Translate `request.text`, serving from the Redis cache when possible.

    Cache reads and writes are best-effort: a Redis outage degrades to an
    uncached translation instead of failing the whole request (previously a
    Redis error surfaced as an unhandled/500 error even when the model could
    have served the translation).
    """
    cache_key = generate_cache_key(request.text, request.source_lang, request.target_lang)

    # Best-effort cache lookup.
    try:
        cached_result = redis_client.get(cache_key)
    except redis.RedisError:
        cached_result = None

    if cached_result:
        return TranslationResponse(
            translated_text=cached_result,
            from_cache=True,
            character_count=len(request.text),
        )

    # Perform translation; only the model call belongs in this try block.
    try:
        # Run the blocking inference in the thread pool so the event loop stays
        # free. get_running_loop() is the supported call inside a coroutine
        # (get_event_loop() here is deprecated since Python 3.10).
        translated_text = await asyncio.get_running_loop().run_in_executor(
            executor,
            translate_text,
            request.text,
            request.source_lang,
            request.target_lang,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Translation error: {str(e)}")

    # Best-effort cache write; a failure here must not lose the translation.
    try:
        redis_client.setex(cache_key, CACHE_TTL, translated_text)
    except redis.RedisError:
        pass

    return TranslationResponse(
        translated_text=translated_text,
        from_cache=False,
        character_count=len(request.text),
    )
107
+
108
@app.get("/health")
async def health_check():
    """Liveness probe: report service status and the active inference device."""
    payload = {"status": "healthy", "device": DEVICE}
    return payload
111
+
112
if __name__ == "__main__":
    # Development entry point: serve on all interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)