danicor committed on
Commit
0e92f6e
·
verified ·
1 Parent(s): 2b18932

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +287 -425
app.py CHANGED
@@ -1,4 +1,3 @@
1
- # translator_server_with_progress.py
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  import time
@@ -9,27 +8,17 @@ from datetime import datetime, timedelta
9
  import threading
10
  from queue import Queue
11
  import logging
12
- from typing import Dict, List, Tuple, Optional, Any
13
- from fastapi import FastAPI, HTTPException, Request, BackgroundTasks
14
  from fastapi.middleware.cors import CORSMiddleware
15
- from fastapi.responses import StreamingResponse, JSONResponse
16
  from pydantic import BaseModel
17
  import uvicorn
18
- import uuid
19
- import asyncio
20
 
21
# Configure root logging once at import time; module loggers inherit this setup.
logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
logger = logging.getLogger("translator_app")
29
 
30
- # ------------------------
31
- # Pydantic models
32
- # ------------------------
33
  class TranslationRequest(BaseModel):
34
  text: str
35
  source_lang: str
@@ -45,67 +34,6 @@ class TranslationResponse(BaseModel):
45
  status: str
46
  chunks_processed: Optional[int] = None
47
 
48
from datetime import timezone  # aware timestamps; duplicate-safe if already imported

# getLogger is idempotent, so re-fetching the module logger here is a no-op
# when the logging setup above has already run.
logger = logging.getLogger("translator_app")


class JobStore:
    """Thread-safe in-memory job store for tracking translation progress and results.

    Records live for the lifetime of the process; there is no eviction,
    so long-running deployments accumulate entries (acceptable for a demo).
    """

    def __init__(self):
        # job_id -> mutable job record; every access goes through _lock
        self._store: Dict[str, Dict[str, Any]] = {}
        self._lock = threading.Lock()

    def create_job(self, text: str, source_lang: str, target_lang: str) -> str:
        """Register a new job in the 'queued' state and return its id."""
        job_id = uuid.uuid4().hex
        with self._lock:
            self._store[job_id] = {
                "job_id": job_id,
                "status": "queued",      # queued, running, success, failed, cancelled
                "progress": 0.0,         # percent 0.0 - 100.0
                "chunks_total": None,
                "chunks_processed": 0,
                "start_time": None,
                "last_update": None,
                "eta_seconds": None,
                "message": "Job created",
                "source_lang": source_lang,
                "target_lang": target_lang,
                "character_count": len(text),
                "result": None,
                "error": None,
            }
        # log outside the lock: no need to hold it during I/O
        logger.info("Created job %s... (chars=%d)", job_id[:8], len(text))
        return job_id

    def update(self, job_id: str, **kwargs):
        """Merge arbitrary fields into a job record and refresh its timestamp.

        Unknown job ids are logged and ignored (never raises).
        """
        with self._lock:
            if job_id not in self._store:
                logger.warning("Attempt to update unknown job %s", job_id)
                return
            job = self._store[job_id]
            job.update(kwargs)
            # timezone-aware replacement for the deprecated datetime.utcnow()
            job["last_update"] = datetime.now(timezone.utc)
            status, progress, message = job["status"], job["progress"], job["message"]
        # concise visibility log (lazy %-args avoid formatting when filtered out)
        logger.info("Job %s... update: status=%s progress=%.1f%% message=%s",
                    job_id[:8], status, progress, message)

    def get(self, job_id: str) -> Optional[Dict[str, Any]]:
        """Return a shallow copy of the job record, or None if unknown."""
        with self._lock:
            return dict(self._store[job_id]) if job_id in self._store else None

    def set_result(self, job_id: str, result: str, status: str = "success", error: Optional[str] = None):
        """Store the final result and terminal status for a job.

        Only a successful job is forced to 100% progress; a failed job keeps
        whatever progress it had reached.
        """
        with self._lock:
            if job_id not in self._store:
                return
            job = self._store[job_id]
            job["result"] = result
            job["status"] = status
            job["error"] = error
            if status == "success":
                job["progress"] = 100.0
            job["last_update"] = datetime.now(timezone.utc)
        logger.info("Job %s... finished with status=%s error=%s", job_id[:8], status, error)


job_store = JobStore()
105
-
106
- # ------------------------
107
- # Cache (unchanged logic but thread-safe)
108
- # ------------------------
109
  class TranslationCache:
110
  def __init__(self, cache_duration_minutes: int = 60):
111
  self.cache = {}
@@ -113,10 +41,12 @@ class TranslationCache:
113
  self.lock = threading.Lock()
114
 
115
  def _generate_key(self, text: str, source_lang: str, target_lang: str) -> str:
 
116
  content = f"{text}_{source_lang}_{target_lang}"
117
  return hashlib.md5(content.encode()).hexdigest()
118
 
119
- def get(self, text: str, source_lang: str, target_lang: str) -> Optional[str]:
 
120
  with self.lock:
121
  key = self._generate_key(text, source_lang, target_lang)
122
  if key in self.cache:
@@ -125,18 +55,17 @@ class TranslationCache:
125
  logger.info(f"Cache hit for key: {key[:8]}...")
126
  return translation
127
  else:
 
128
  del self.cache[key]
129
  return None
130
 
131
  def set(self, text: str, source_lang: str, target_lang: str, translation: str):
 
132
  with self.lock:
133
  key = self._generate_key(text, source_lang, target_lang)
134
  self.cache[key] = (translation, datetime.now())
135
  logger.info(f"Cached translation for key: {key[:8]}...")
136
 
137
- # ------------------------
138
- # Queue for background tasks (keeps existing behavior)
139
- # ------------------------
140
  class TranslationQueue:
141
  def __init__(self, max_workers: int = 3):
142
  self.queue = Queue()
@@ -145,9 +74,11 @@ class TranslationQueue:
145
  self.lock = threading.Lock()
146
 
147
  def add_task(self, task_func, *args, **kwargs):
 
148
  self.queue.put((task_func, args, kwargs))
149
 
150
  def process_queue(self):
 
151
  while not self.queue.empty():
152
  with self.lock:
153
  if self.current_workers >= self.max_workers:
@@ -160,35 +91,43 @@ class TranslationQueue:
160
 
161
  def worker():
162
  try:
163
- task_func(*args, **kwargs)
 
164
  finally:
165
  with self.lock:
166
  self.current_workers -= 1
167
 
168
- thread = threading.Thread(target=worker, daemon=True)
169
  thread.start()
170
 
171
- translation_queue = TranslationQueue(max_workers=3)
172
-
173
- # ------------------------
174
- # Text chunker (unchanged)
175
- # ------------------------
176
  class TextChunker:
 
 
177
  @staticmethod
178
  def split_text_smart(text: str, max_chunk_size: int = 400) -> List[str]:
 
179
  if len(text) <= max_chunk_size:
180
  return [text]
 
181
  chunks = []
 
 
182
  paragraphs = text.split('\n\n')
183
  current_chunk = ""
 
184
  for paragraph in paragraphs:
 
185
  if len(paragraph) > max_chunk_size:
 
186
  if current_chunk.strip():
187
  chunks.append(current_chunk.strip())
188
  current_chunk = ""
 
 
189
  sub_chunks = TextChunker._split_paragraph(paragraph, max_chunk_size)
190
  chunks.extend(sub_chunks)
191
  else:
 
192
  if len(current_chunk) + len(paragraph) + 2 > max_chunk_size:
193
  if current_chunk.strip():
194
  chunks.append(current_chunk.strip())
@@ -198,24 +137,35 @@ class TextChunker:
198
  current_chunk += "\n\n" + paragraph
199
  else:
200
  current_chunk = paragraph
 
 
201
  if current_chunk.strip():
202
  chunks.append(current_chunk.strip())
 
203
  return chunks
204
 
205
  @staticmethod
206
  def _split_paragraph(paragraph: str, max_chunk_size: int) -> List[str]:
 
 
207
  sentences = re.split(r'[.!?]+\s+', paragraph)
208
  chunks = []
209
  current_chunk = ""
 
210
  for sentence in sentences:
211
  if not sentence.strip():
212
  continue
 
 
213
  if not sentence.endswith(('.', '!', '?')):
214
  sentence += '.'
 
215
  if len(sentence) > max_chunk_size:
 
216
  if current_chunk.strip():
217
  chunks.append(current_chunk.strip())
218
  current_chunk = ""
 
219
  sub_chunks = TextChunker._split_by_comma(sentence, max_chunk_size)
220
  chunks.extend(sub_chunks)
221
  else:
@@ -228,23 +178,31 @@ class TextChunker:
228
  current_chunk += " " + sentence
229
  else:
230
  current_chunk = sentence
 
231
  if current_chunk.strip():
232
  chunks.append(current_chunk.strip())
 
233
  return chunks
234
 
235
  @staticmethod
236
  def _split_by_comma(sentence: str, max_chunk_size: int) -> List[str]:
 
237
  parts = sentence.split(', ')
238
  chunks = []
239
  current_chunk = ""
 
240
  for part in parts:
241
  if len(part) > max_chunk_size:
 
242
  if current_chunk.strip():
243
  chunks.append(current_chunk.strip())
244
  current_chunk = ""
 
 
245
  while len(part) > max_chunk_size:
246
  chunks.append(part[:max_chunk_size].strip())
247
  part = part[max_chunk_size:].strip()
 
248
  if part:
249
  current_chunk = part
250
  else:
@@ -257,16 +215,180 @@ class TextChunker:
257
  current_chunk += ", " + part
258
  else:
259
  current_chunk = part
 
260
  if current_chunk.strip():
261
  chunks.append(current_chunk.strip())
 
262
  return chunks
263
 
264
- # ------------------------
265
- # Language map (same)
266
- # ------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  LANGUAGE_MAP = {
268
  "English": "en",
269
- "Persian (Farsi)": "fa",
270
  "Arabic": "ar",
271
  "French": "fr",
272
  "German": "de",
@@ -334,194 +456,13 @@ LANGUAGE_MAP = {
334
  "Zulu": "zu"
335
  }
336
 
337
from datetime import timezone  # aware timestamps; duplicate-safe if already imported


class MultilingualTranslator:
    """Wraps the M2M-100 model with caching, smart chunking and job-progress reporting."""

    def __init__(self, cache_duration_minutes: int = 60):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info("Using device: %s", self.device)

        self.cache = TranslationCache(cache_duration_minutes)
        self.queue = translation_queue

        # Load model -- can take minutes on first run while weights download.
        self.model_name = "facebook/m2m100_1.2B"
        logger.info("Loading model: %s", self.model_name)
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
            self.model.to(self.device)
            logger.info("Model loaded successfully!")
        except Exception:
            logger.exception("Error loading model")
            raise

        self.max_chunk_size = 350    # max characters per chunk sent to the model
        self.min_chunk_overlap = 20  # reserved for future overlap support (currently unused)

    def translate_chunk(self, text: str, source_lang: str, target_lang: str) -> str:
        """Translate one small piece of text; on error return an inline error marker."""
        try:
            # Some M2M tokenizers require the src_lang attribute; ignore if absent.
            try:
                self.tokenizer.src_lang = source_lang
            except Exception:
                pass

            encoded = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(self.device)
            generated_tokens = self.model.generate(
                **encoded,
                forced_bos_token_id=self.tokenizer.get_lang_id(target_lang) if hasattr(self.tokenizer, "get_lang_id") else None,
                max_length=1024,
                min_length=10,
                num_beams=5,
                early_stopping=True,
                no_repeat_ngram_size=3,
                length_penalty=1.0,
                repetition_penalty=1.2,
                do_sample=False,  # deterministic beam search
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )
            translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
            return translation.strip()
        except Exception as e:
            logger.exception("Chunk translation error")
            return f"[Translation Error: {str(e)}]"

    def translate_text(self, text: str, source_lang: str, target_lang: str, job_id: Optional[str] = None) -> Tuple[str, float, int]:
        """Translate text, chunking long inputs.

        If job_id is provided, progress is published to job_store as work proceeds.
        Returns (translation, processing_time_seconds, chunks_count).
        """
        start_time = time.time()
        if job_id:
            job_store.update(job_id, status="running", message="Starting translation",
                             start_time=datetime.now(timezone.utc))

        # Whole-text cache hit short-circuits everything.
        cached_result = self.cache.get(text, source_lang, target_lang)
        if cached_result:
            processing_time = time.time() - start_time
            if job_id:
                job_store.set_result(job_id, cached_result, status="success")
                job_store.update(job_id, progress=100.0, chunks_processed=1, chunks_total=1,
                                 message="Cache hit - completed", eta_seconds=0)
            logger.info("Cache returned result in %.2fs", processing_time)
            return cached_result, processing_time, 1

        try:
            if len(text) <= self.max_chunk_size:
                # Single-chunk fast path.
                if job_id:
                    job_store.update(job_id, chunks_total=1, chunks_processed=0, message="Translating single chunk")
                translation = self.translate_chunk(text, source_lang, target_lang)
                self.cache.set(text, source_lang, target_lang, translation)
                processing_time = time.time() - start_time
                if job_id:
                    job_store.set_result(job_id, translation, status="success")
                    job_store.update(job_id, progress=100.0, chunks_processed=1, chunks_total=1,
                                     message="Completed", eta_seconds=0)
                logger.info("Short text translation completed in %.2f seconds", processing_time)
                return translation, processing_time, 1

            # Long text: split, translate chunk by chunk with progress/ETA updates.
            chunks = TextChunker.split_text_smart(text, self.max_chunk_size)
            total_chunks = len(chunks)
            if job_id:
                job_store.update(job_id, chunks_total=total_chunks, chunks_processed=0,
                                 progress=0.0, message=f"Split into {total_chunks} chunks")
            logger.info("Split long text into %d chunks", total_chunks)

            translated_chunks = []
            chunk_times: List[float] = []
            for i, chunk in enumerate(chunks):
                chunk_start = time.time()
                logger.info("Translating chunk %d/%d length=%d", i + 1, total_chunks, len(chunk))
                if job_id:
                    job_store.update(job_id, message=f"Translating chunk {i+1}/{total_chunks}")

                # Per-chunk cache: repeated paragraphs are translated once.
                chunk_cached = self.cache.get(chunk, source_lang, target_lang)
                if chunk_cached:
                    ct = chunk_cached
                    logger.info("Chunk %d cache hit", i + 1)
                else:
                    ct = self.translate_chunk(chunk, source_lang, target_lang)
                    self.cache.set(chunk, source_lang, target_lang, ct)

                translated_chunks.append(ct)
                chunk_times.append(time.time() - chunk_start)

                # Progress + ETA from the running average chunk time.
                processed = i + 1
                avg = sum(chunk_times) / len(chunk_times) if chunk_times else 0.0
                remaining = max(0, total_chunks - processed)
                eta = avg * remaining
                progress_percent = (processed / total_chunks) * 100.0

                if job_id:
                    job_store.update(job_id,
                                     chunks_processed=processed,
                                     progress=round(progress_percent, 2),
                                     eta_seconds=round(eta, 1),
                                     message=f"Processed {processed}/{total_chunks} chunks (avg_chunk={avg:.2f}s)")

                # Small throttle between chunks to be kind to the device.
                if i < total_chunks - 1:
                    time.sleep(0.05)

            final_translation = self._combine_translations(translated_chunks, text)
            self.cache.set(text, source_lang, target_lang, final_translation)
            processing_time = time.time() - start_time

            if job_id:
                job_store.set_result(job_id, final_translation, status="success")
                job_store.update(job_id, progress=100.0, chunks_processed=total_chunks, chunks_total=total_chunks,
                                 message=f"Completed in {processing_time:.2f}s", eta_seconds=0)

            logger.info("Long text translation completed in %.2f seconds (%d chunks)", processing_time, total_chunks)
            return final_translation, processing_time, total_chunks

        except Exception as e:
            logger.exception("Translation error")
            processing_time = time.time() - start_time
            if job_id:
                job_store.set_result(job_id, "", status="failed", error=str(e))
                job_store.update(job_id, progress=0.0, message=f"Failed: {str(e)}")
            return f"Translation error: {str(e)}", processing_time, 0

    def _combine_translations(self, translated_chunks: List[str], original_text: str) -> str:
        """Rejoin translated chunks, preserving paragraph breaks of the original.

        Adds a terminating period between chunks when the previous chunk ends
        without sentence punctuation, and separates chunks with a blank line
        when the original text was paragraph-structured.
        """
        if not translated_chunks:
            return ""
        if len(translated_chunks) == 1:
            return translated_chunks[0]
        combined = []
        for i, chunk in enumerate(translated_chunks):
            chunk = chunk.strip()
            if not chunk:
                continue
            if i > 0 and combined:
                if not combined[-1].rstrip().endswith(('.', '!', '?', ':', '؛', '.')):
                    combined[-1] += '.'
                if '\n\n' in original_text:
                    combined.append('\n\n' + chunk)
                else:
                    combined.append(' ' + chunk)
            else:
                combined.append(chunk)
        result = ''.join(combined)
        # BUGFIX: collapse runs of spaces/tabs only -- the previous r'\s+' also
        # flattened the '\n\n' paragraph separators inserted just above.
        result = re.sub(r'[ \t]+', ' ', result)
        result = re.sub(r'\.+', '.', result)
        return result.strip()
516
-
517
- # initialize translator (loads model) - this can take time at startup
518
  translator = MultilingualTranslator(60)
519
 
520
- # ------------------------
521
- # FastAPI app
522
- # ------------------------
523
- app = FastAPI(title="Multilingual Translation API with Progress", version="2.0.0")
524
 
 
525
  app.add_middleware(
526
  CORSMiddleware,
527
  allow_origins=["*"],
@@ -532,169 +473,92 @@ app.add_middleware(
532
 
533
@app.get("/")
async def root():
    """Service banner: API name, liveness flag and supported feature list."""
    info = {
        "message": "Multilingual Translation API v2.0 (with progress)",
        "status": "active",
        "features": ["long_text_support", "smart_chunking", "cache_optimization", "progress_tracking", "sse"],
    }
    return info
536
 
537
# Synchronous translate endpoint: blocks until the translation finishes, but
# still registers a job so other clients can poll its progress concurrently.
@app.post("/api/translate")
async def api_translate(request: TranslationRequest):
    """Translate request.text and return the complete result in one response.

    Raises 400 on empty text or unknown language names, 500 on translation failure.
    """
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="No text provided")

    source_code = LANGUAGE_MAP.get(request.source_lang)
    target_code = LANGUAGE_MAP.get(request.target_lang)
    if not source_code or not target_code:
        raise HTTPException(status_code=400, detail="Invalid language codes")

    # Create a job so consumers can watch progress even for this sync call.
    job_id = job_store.create_job(request.text, request.source_lang, request.target_lang)
    job_store.update(job_id, message="Synchronous translation requested")

    # BUGFIX: the previous implementation started a worker thread and then
    # immediately join()ed it, which blocks the event loop exactly as much as
    # a direct call while adding thread overhead -- so call directly.
    try:
        translation, processing_time, chunks_count = translator.translate_text(
            request.text, source_code, target_code, job_id=job_id
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Translation error: {e}")

    return TranslationResponse(
        translation=translation,
        source_language=request.source_lang,
        target_language=request.target_lang,
        processing_time=processing_time,
        character_count=len(request.text),
        status="success",
        chunks_processed=chunks_count,
    )
-
581
# Background endpoint: responds immediately with a job_id; the actual work
# runs after the response via FastAPI's background-task machinery.
@app.post("/api/translate_async")
async def api_translate_async(request: TranslationRequest, background_tasks: BackgroundTasks):
    """Validate the request, enqueue a background translation, return its job id."""
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="No text provided")
    source_code = LANGUAGE_MAP.get(request.source_lang)
    target_code = LANGUAGE_MAP.get(request.target_lang)
    if not source_code or not target_code:
        raise HTTPException(status_code=400, detail="Invalid language codes")

    job_id = job_store.create_job(request.text, request.source_lang, request.target_lang)

    def run_translation(payload_text, src, dst, pending_job):
        # Executed by Starlette after the response is sent; any failure is
        # recorded on the job instead of propagating.
        try:
            translator.translate_text(payload_text, src, dst, job_id=pending_job)
        except Exception as exc:
            logger.exception("Background translation failed")
            job_store.set_result(pending_job, "", status="failed", error=str(exc))
            job_store.update(pending_job, message="Background task failed")

    background_tasks.add_task(run_translation, request.text, source_code, target_code, job_id)
    return {"job_id": job_id, "status": "accepted", "message": "Translation started in background. Use /api/job/{job_id} or /api/stream/{job_id} to monitor progress."}
603
-
604
# Poll endpoint for job progress.
@app.get("/api/job/{job_id}")
async def get_job_status(job_id: str):
    """Return a client-safe snapshot of a job's progress (never the result text)."""
    job = job_store.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    # Expose only this whitelist of fields; the full translation stays private
    # until fetched explicitly via /api/result/{job_id}.
    public_fields = (
        "job_id", "status", "progress", "chunks_total", "chunks_processed",
        "eta_seconds", "message", "source_lang", "target_lang",
        "character_count", "error",
    )
    safe = {field: job[field] for field in public_fields}
    safe["translation_available"] = job["result"] is not None and job["status"] == "success"
    return safe
629
-
630
# SSE stream for live updates (client can connect with EventSource)
@app.get("/api/stream/{job_id}")
async def stream_job_progress(job_id: str):
    """Stream job progress as Server-Sent Events until the job reaches a terminal state.

    Emits a `data:` event whenever the job record changes (polled every 0.5s),
    an `error` event if the job vanishes mid-stream, and a final `close` event.
    Raises 404 if the job id is unknown at connect time.
    """
    job = job_store.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")

    async def event_generator():
        # Poll job_store and push snapshots; runs inside the StreamingResponse.
        logger.info(f"SSE client connected for job {job_id[:8]}...")
        last_snapshot = None  # last record sent; used to suppress duplicate events
        while True:
            job_snapshot = job_store.get(job_id)
            if job_snapshot is None:
                # job disappeared from the store after the initial check
                yield f"event: error\ndata: {json.dumps({'message': 'job not found'})}\n\n"
                break

            # send update only if changed (dict comparison includes last_update,
            # so any job_store.update() call produces a new event)
            if job_snapshot != last_snapshot:
                payload = {
                    "job_id": job_snapshot["job_id"],
                    "status": job_snapshot["status"],
                    "progress": job_snapshot["progress"],
                    "chunks_total": job_snapshot["chunks_total"],
                    "chunks_processed": job_snapshot["chunks_processed"],
                    "eta_seconds": job_snapshot["eta_seconds"],
                    "message": job_snapshot["message"],
                    "source_lang": job_snapshot["source_lang"],
                    "target_lang": job_snapshot["target_lang"],
                    "character_count": job_snapshot["character_count"],
                    "error": job_snapshot["error"],
                }
                # if completed and success, include small result preview
                # (not the full text, to avoid huge SSE frames)
                if job_snapshot["status"] in ("success", "failed") and job_snapshot["result"] is not None:
                    payload["result_preview"] = job_snapshot["result"][:1000]  # first 1k chars
                # default=str serializes the datetime fields in the record
                data = json.dumps(payload, default=str)
                yield f"data: {data}\n\n"
                last_snapshot = job_snapshot

            # stop once the job reaches a terminal state
            if job_snapshot["status"] in ("success", "failed", "cancelled"):
                logger.info(f"SSE: job {job_id[:8]} finished with status {job_snapshot['status']}")
                break

            await asyncio.sleep(0.5)  # poll interval

        # final close message so clients can stop reconnecting
        yield f"event: close\ndata: {json.dumps({'message': 'stream closed'})}\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")
680
-
681
# Fetch the final translation once a job has succeeded.
@app.get("/api/result/{job_id}")
async def get_result(job_id: str):
    """Return the completed translation; 202 while pending, 404 if unknown."""
    job = job_store.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    if job["status"] == "success":
        return {"job_id": job_id, "translation": job["result"], "character_count": job["character_count"]}
    return JSONResponse(status_code=202, content={"status": job["status"], "message": "Result not ready"})
690
-
691
- # languages and health (preserve)
692
  @app.get("/api/languages")
693
  async def get_languages():
694
- return {"languages": list(LANGUAGE_MAP.keys()), "language_codes": LANGUAGE_MAP, "status": "success"}
 
 
 
 
 
695
 
696
  @app.get("/api/health")
697
  async def health_check():
 
698
  return {
699
  "status": "healthy",
700
  "device": str(translator.device),
@@ -704,7 +568,5 @@ async def health_check():
704
  "version": "2.0.0"
705
  }
706
 
707
# Run
if __name__ == "__main__":
    # Development runner only. IMPORTANT: for production, use uvicorn/gunicorn
    # with worker processes and proper GPU visibility instead of this call.
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
 
 
1
  import torch
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import time
 
8
  import threading
9
  from queue import Queue
10
  import logging
11
+ from typing import Dict, List, Tuple, Optional
12
+ from fastapi import FastAPI, HTTPException, Request
13
  from fastapi.middleware.cors import CORSMiddleware
 
14
  from pydantic import BaseModel
15
  import uvicorn
 
 
16
 
17
+ # Set up logging
18
+ logging.basicConfig(level=logging.INFO)
19
+ logger = logging.getLogger(__name__)
 
 
 
 
 
20
 
21
+ # Pydantic models for request/response
 
 
22
  class TranslationRequest(BaseModel):
23
  text: str
24
  source_lang: str
 
34
  status: str
35
  chunks_processed: Optional[int] = None
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  class TranslationCache:
38
  def __init__(self, cache_duration_minutes: int = 60):
39
  self.cache = {}
 
41
  self.lock = threading.Lock()
42
 
43
  def _generate_key(self, text: str, source_lang: str, target_lang: str) -> str:
44
+ """Generate cache key from text and languages"""
45
  content = f"{text}_{source_lang}_{target_lang}"
46
  return hashlib.md5(content.encode()).hexdigest()
47
 
48
+ def get(self, text: str, source_lang: str, target_lang: str) -> str:
49
+ """Get translation from cache if exists and not expired"""
50
  with self.lock:
51
  key = self._generate_key(text, source_lang, target_lang)
52
  if key in self.cache:
 
55
  logger.info(f"Cache hit for key: {key[:8]}...")
56
  return translation
57
  else:
58
+ # Remove expired entry
59
  del self.cache[key]
60
  return None
61
 
62
  def set(self, text: str, source_lang: str, target_lang: str, translation: str):
63
+ """Store translation in cache"""
64
  with self.lock:
65
  key = self._generate_key(text, source_lang, target_lang)
66
  self.cache[key] = (translation, datetime.now())
67
  logger.info(f"Cached translation for key: {key[:8]}...")
68
 
 
 
 
69
  class TranslationQueue:
70
  def __init__(self, max_workers: int = 3):
71
  self.queue = Queue()
 
74
  self.lock = threading.Lock()
75
 
76
  def add_task(self, task_func, *args, **kwargs):
77
+ """Add translation task to queue"""
78
  self.queue.put((task_func, args, kwargs))
79
 
80
  def process_queue(self):
81
+ """Process tasks from queue"""
82
  while not self.queue.empty():
83
  with self.lock:
84
  if self.current_workers >= self.max_workers:
 
91
 
92
  def worker():
93
  try:
94
+ result = task_func(*args, **kwargs)
95
+ return result
96
  finally:
97
  with self.lock:
98
  self.current_workers -= 1
99
 
100
+ thread = threading.Thread(target=worker)
101
  thread.start()
102
 
 
 
 
 
 
103
  class TextChunker:
104
+ """کلاس برای تقسیم متن طولانی به بخش‌های کوچکتر"""
105
+
106
  @staticmethod
107
  def split_text_smart(text: str, max_chunk_size: int = 400) -> List[str]:
108
+ """تقسیم هوشمند متن بر اساس جملات و پاراگراف‌ها"""
109
  if len(text) <= max_chunk_size:
110
  return [text]
111
+
112
  chunks = []
113
+
114
+ # تقسیم بر اساس پاراگراف‌ها
115
  paragraphs = text.split('\n\n')
116
  current_chunk = ""
117
+
118
  for paragraph in paragraphs:
119
+ # اگر پاراگراف خودش بزرگ است، آن را تقسیم کن
120
  if len(paragraph) > max_chunk_size:
121
+ # ذخیره قسمت فعلی اگر وجود دارد
122
  if current_chunk.strip():
123
  chunks.append(current_chunk.strip())
124
  current_chunk = ""
125
+
126
+ # تقسیم پاراگراف بزرگ
127
  sub_chunks = TextChunker._split_paragraph(paragraph, max_chunk_size)
128
  chunks.extend(sub_chunks)
129
  else:
130
+ # بررسی اینکه آیا اضافه کردن این پاراگراف از حد تجاوز می‌کند
131
  if len(current_chunk) + len(paragraph) + 2 > max_chunk_size:
132
  if current_chunk.strip():
133
  chunks.append(current_chunk.strip())
 
137
  current_chunk += "\n\n" + paragraph
138
  else:
139
  current_chunk = paragraph
140
+
141
+ # اضافه کردن آخرین قسمت
142
  if current_chunk.strip():
143
  chunks.append(current_chunk.strip())
144
+
145
  return chunks
146
 
147
  @staticmethod
148
  def _split_paragraph(paragraph: str, max_chunk_size: int) -> List[str]:
149
+ """تقسیم پاراگراف بزرگ به جملات"""
150
+ # تقسیم بر اساس جملات
151
  sentences = re.split(r'[.!?]+\s+', paragraph)
152
  chunks = []
153
  current_chunk = ""
154
+
155
  for sentence in sentences:
156
  if not sentence.strip():
157
  continue
158
+
159
+ # اضافه کردن علامت نقطه اگر حذف شده
160
  if not sentence.endswith(('.', '!', '?')):
161
  sentence += '.'
162
+
163
  if len(sentence) > max_chunk_size:
164
+ # جمله خودش خیلی بلند است - تقسیم بر اساس کاما
165
  if current_chunk.strip():
166
  chunks.append(current_chunk.strip())
167
  current_chunk = ""
168
+
169
  sub_chunks = TextChunker._split_by_comma(sentence, max_chunk_size)
170
  chunks.extend(sub_chunks)
171
  else:
 
178
  current_chunk += " " + sentence
179
  else:
180
  current_chunk = sentence
181
+
182
  if current_chunk.strip():
183
  chunks.append(current_chunk.strip())
184
+
185
  return chunks
186
 
187
  @staticmethod
188
  def _split_by_comma(sentence: str, max_chunk_size: int) -> List[str]:
189
+ """تقسیم جمله طولانی بر اساس کاما"""
190
  parts = sentence.split(', ')
191
  chunks = []
192
  current_chunk = ""
193
+
194
  for part in parts:
195
  if len(part) > max_chunk_size:
196
+ # قسمت خودش خیلی بلند است - تقسیم اجباری
197
  if current_chunk.strip():
198
  chunks.append(current_chunk.strip())
199
  current_chunk = ""
200
+
201
+ # تقسیم اجباری بر اساس طول
202
  while len(part) > max_chunk_size:
203
  chunks.append(part[:max_chunk_size].strip())
204
  part = part[max_chunk_size:].strip()
205
+
206
  if part:
207
  current_chunk = part
208
  else:
 
215
  current_chunk += ", " + part
216
  else:
217
  current_chunk = part
218
+
219
  if current_chunk.strip():
220
  chunks.append(current_chunk.strip())
221
+
222
  return chunks
223
 
224
class MultilingualTranslator:
    """M2M100-based translator with caching and long-text chunking support.

    Loads ``facebook/m2m100_1.2B`` once at construction time and keeps it on
    GPU when available. Long inputs are split into chunks (via ``TextChunker``),
    translated chunk-by-chunk (each chunk individually cached), and re-joined.
    """

    def __init__(self, cache_duration_minutes: int = 60):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        # Initialize cache and queue
        self.cache = TranslationCache(cache_duration_minutes)
        self.queue = TranslationQueue()

        # Load model - using a powerful multilingual model
        self.model_name = "facebook/m2m100_1.2B"
        logger.info(f"Loading model: {self.model_name}")

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
            self.model.to(self.device)
            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

        # Tuning for long-text translation.
        self.max_chunk_size = 350   # maximum characters per chunk
        self.min_chunk_overlap = 20  # overlap between chunks (used by the chunker)

    def translate_chunk(self, text: str, source_lang: str, target_lang: str) -> str:
        """Translate one small piece of text.

        Returns the translated string, or an ``[Translation Error: ...]``
        marker string on failure (callers treat errors as inline text).
        """
        try:
            # M2M100 requires the source language to be set on the tokenizer.
            self.tokenizer.src_lang = source_lang

            # Encode input (hard truncation guards against oversized chunks).
            encoded = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(self.device)

            # Deterministic beam search; `forced_bos_token_id` selects the target language.
            # NOTE: the previous version also passed temperature=0.7, which is a
            # sampling parameter and is ignored (with a warning) when do_sample=False,
            # so it has been removed.
            generated_tokens = self.model.generate(
                **encoded,
                forced_bos_token_id=self.tokenizer.get_lang_id(target_lang),
                max_length=1024,            # allow long outputs
                min_length=10,              # avoid degenerate short outputs
                num_beams=5,                # wider beam for better quality
                early_stopping=True,
                no_repeat_ngram_size=3,     # suppress repeated phrases
                length_penalty=1.0,
                repetition_penalty=1.2,     # suppress repeated words
                do_sample=False,            # deterministic decoding
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )

            # Decode and strip stray whitespace around the result.
            translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
            return translation.strip()

        except Exception as e:
            logger.error(f"Chunk translation error: {e}")
            return f"[Translation Error: {str(e)}]"

    def translate_text(self, text: str, source_lang: str, target_lang: str) -> Tuple[str, float, int]:
        """Translate ``text``, chunking when it exceeds ``max_chunk_size``.

        Returns ``(translation, processing_time_seconds, chunks_processed)``.
        On failure returns an error-message string with chunk count 0.
        """
        start_time = time.time()

        # Whole-text cache hit short-circuits everything.
        cached_result = self.cache.get(text, source_lang, target_lang)
        if cached_result:
            return cached_result, time.time() - start_time, 1

        try:
            # Short text: translate directly, no chunking.
            if len(text) <= self.max_chunk_size:
                translation = self.translate_chunk(text, source_lang, target_lang)

                self.cache.set(text, source_lang, target_lang, translation)
                processing_time = time.time() - start_time
                logger.info(f"Short text translation completed in {processing_time:.2f} seconds")

                return translation, processing_time, 1

            # Long text: split into smaller chunks.
            chunks = TextChunker.split_text_smart(text, self.max_chunk_size)
            logger.info(f"Split long text into {len(chunks)} chunks")

            translated_chunks = []
            for i, chunk in enumerate(chunks):
                logger.info(f"Translating chunk {i+1}/{len(chunks)} (length: {len(chunk)})")

                # Per-chunk cache lookup before hitting the model.
                chunk_translation = self.cache.get(chunk, source_lang, target_lang)

                if not chunk_translation:
                    chunk_translation = self.translate_chunk(chunk, source_lang, target_lang)
                    self.cache.set(chunk, source_lang, target_lang, chunk_translation)

                translated_chunks.append(chunk_translation)

                # Brief pause between chunks to avoid saturating the device.
                if i < len(chunks) - 1:
                    time.sleep(0.1)

            # Stitch the translated chunks back together.
            final_translation = self._combine_translations(translated_chunks, text)

            # Cache the combined result under the full original text.
            self.cache.set(text, source_lang, target_lang, final_translation)

            processing_time = time.time() - start_time
            logger.info(f"Long text translation completed in {processing_time:.2f} seconds ({len(chunks)} chunks)")

            return final_translation, processing_time, len(chunks)

        except Exception as e:
            logger.error(f"Translation error: {e}")
            return f"Translation error: {str(e)}", time.time() - start_time, 0

    def _combine_translations(self, translated_chunks: List[str], original_text: str) -> str:
        """Merge per-chunk translations into one coherent text.

        Ensures each chunk ends with sentence-final punctuation before joining.
        If the original text contained paragraph breaks (``\\n\\n``), chunks are
        joined with paragraph breaks; otherwise with a single space.
        """
        if not translated_chunks:
            return ""

        if len(translated_chunks) == 1:
            return translated_chunks[0]

        # Decide the separator once, based on the original text's structure.
        preserve_paragraphs = '\n\n' in original_text

        combined = []
        for i, chunk in enumerate(translated_chunks):
            chunk = chunk.strip()
            if not chunk:
                continue

            if i > 0 and combined:
                # If the previous chunk lacks terminal punctuation, add a period.
                # (Fixed: the tuple previously listed '.' twice.)
                if not combined[-1].rstrip().endswith(('.', '!', '?', ':', '؛')):
                    combined[-1] += '.'

                if preserve_paragraphs:
                    combined.append('\n\n' + chunk)
                else:
                    combined.append(' ' + chunk)
            else:
                combined.append(chunk)

        result = ''.join(combined)

        # Final cleanup. Collapse runs of spaces/tabs only — the previous
        # r'\s+' pattern also flattened the '\n\n' paragraph separators that
        # were just inserted, making the paragraph branch above dead code.
        result = re.sub(r'[ \t]+', ' ', result)
        result = re.sub(r'\.{2,}', '.', result)  # collapse repeated periods
        return result.strip()
387
+
388
+ # Language mappings for M2M100 model
389
  LANGUAGE_MAP = {
390
  "English": "en",
391
+ "Persian (Farsi)": "fa",
392
  "Arabic": "ar",
393
  "French": "fr",
394
  "German": "de",
 
456
  "Zulu": "zu"
457
  }
458
 
459
+ # Initialize translator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Module-level singleton: constructing this loads the M2M100 model at import
# time (can take a while and significant memory). 60 = cache TTL in minutes.
translator = MultilingualTranslator(60)
 
462
# Create FastAPI app (CORS middleware is attached right below).
app = FastAPI(title="Multilingual Translation API", version="2.0.0")
 
464
 
465
+ # Add CORS middleware
466
  app.add_middleware(
467
  CORSMiddleware,
468
  allow_origins=["*"],
 
473
 
474
@app.get("/")
async def root():
    """Service banner: API name, liveness status, and advertised features."""
    features = ["long_text_support", "smart_chunking", "cache_optimization"]
    return {
        "message": "Multilingual Translation API v2.0",
        "status": "active",
        "features": features,
    }
477
 
 
478
@app.post("/api/translate")
async def api_translate(request: TranslationRequest):
    """Translate a JSON request body; supports arbitrarily long text via chunking.

    Raises 400 for empty text or unknown language names, 500 on translation failure.
    """
    body_text = request.text
    if not body_text.strip():
        raise HTTPException(status_code=400, detail="No text provided")

    # Map human-readable language names to M2M100 codes.
    src_code = LANGUAGE_MAP.get(request.source_lang)
    tgt_code = LANGUAGE_MAP.get(request.target_lang)
    if not src_code or not tgt_code:
        raise HTTPException(status_code=400, detail="Invalid language codes")

    try:
        result, elapsed, n_chunks = translator.translate_text(body_text, src_code, tgt_code)
        return TranslationResponse(
            translation=result,
            source_language=request.source_lang,
            target_language=request.target_lang,
            processing_time=elapsed,
            character_count=len(body_text),
            status="success",
            chunks_processed=n_chunks,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Translation error: {str(e)}")
504
+
505
# Alternative endpoint for form data (compatibility with WordPress)
@app.post("/api/translate/form")
async def api_translate_form(request: Request):
    """Translate text supplied as form data, falling back to a JSON body.

    Accepts fields: text, source_lang, target_lang, api_key.
    Raises 400 for an unreadable body, empty text, or unknown language names;
    500 on translation failure.
    """
    try:
        form_data = await request.form()
        text = form_data.get("text", "")
        source_lang = form_data.get("source_lang", "")
        target_lang = form_data.get("target_lang", "")
        api_key = form_data.get("api_key", None)  # accepted for compatibility; not checked here
    except Exception:
        # Fixed: was a bare `except:`, which also swallows SystemExit/KeyboardInterrupt.
        try:
            # Try to get JSON data if form data fails
            json_data = await request.json()
            text = json_data.get("text", "")
            source_lang = json_data.get("source_lang", "")
            target_lang = json_data.get("target_lang", "")
            api_key = json_data.get("api_key", None)
        except Exception:
            raise HTTPException(status_code=400, detail="Invalid request format")

    if not text.strip():
        raise HTTPException(status_code=400, detail="No text provided")

    source_code = LANGUAGE_MAP.get(source_lang)
    target_code = LANGUAGE_MAP.get(target_lang)

    if not source_code or not target_code:
        raise HTTPException(status_code=400, detail="Invalid language codes")

    try:
        translation, processing_time, chunks_count = translator.translate_text(text, source_code, target_code)

        return {
            "translation": translation,
            "source_language": source_lang,
            "target_language": target_lang,
            "processing_time": processing_time,
            "character_count": len(text),
            "status": "success",
            "chunks_processed": chunks_count
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Translation error: {str(e)}")
549
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
@app.get("/api/languages")
async def get_languages():
    """Expose the supported language names and their model language codes."""
    names = list(LANGUAGE_MAP.keys())
    payload = {
        "languages": names,
        "language_codes": LANGUAGE_MAP,
        "status": "success",
    }
    return payload
558
 
559
  @app.get("/api/health")
560
  async def health_check():
561
+ """Health check endpoint"""
562
  return {
563
  "status": "healthy",
564
  "device": str(translator.device),
 
568
  "version": "2.0.0"
569
  }
570
 
 
571
  if __name__ == "__main__":
572
+ uvicorn.run(app, host="0.0.0.0", port=7860)