DarkMo0o commited on
Commit
e7cfc43
·
verified ·
1 Parent(s): b41c796

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -35
app.py CHANGED
@@ -1,53 +1,78 @@
1
  from fastapi import FastAPI, File, UploadFile, Form
2
- from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
 
 
3
 
4
  app = FastAPI()
5
 
6
- model_name = "facebook/m2m100_418M"
7
- model = M2M100ForConditionalGeneration.from_pretrained(model_name)
8
- tokenizer = M2M100Tokenizer.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  @app.post("/translate-text")
11
  async def translate_text(
12
  text: str = Form(...),
13
- source_lang: str = Form(...),
14
  target_lang: str = Form(...)
15
  ):
16
- tokenizer.src_lang = source_lang
17
- encoded = tokenizer(text, return_tensors="pt")
18
- generated = model.generate(**encoded, forced_bos_token_id=tokenizer.get_lang_id(target_lang))
19
- translated = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
20
- return {"translated_text": translated}
 
 
 
 
 
 
 
 
21
 
22
  @app.post("/translate-file")
23
  async def translate_file(
24
  file: UploadFile = File(...),
25
- source_lang: str = Form(...),
26
  target_lang: str = Form(...)
27
  ):
28
  contents = await file.read()
29
  original_text = contents.decode()
30
-
31
- # قسم النص الأصلي إلى شرائح في حدود 900 حرف تقريباً لكل شريحة (يمكنك ضبط الرقم حسب تجربة الأداء)
32
- lines = original_text.splitlines()
33
- chunks = []
34
- chunk = ""
35
- max_chunk_length = 900
36
- for line in lines:
37
- if len(chunk) + len(line) < max_chunk_length:
38
- chunk += line + "\n"
39
- else:
40
- chunks.append(chunk.strip())
41
- chunk = line + "\n"
42
- if chunk:
43
- chunks.append(chunk.strip())
44
-
45
- result = ""
46
- for chunk in chunks:
47
- tokenizer.src_lang = source_lang
48
- encoded = tokenizer(chunk, return_tensors="pt")
49
- generated = model.generate(**encoded, forced_bos_token_id=tokenizer.get_lang_id(target_lang))
50
- translated = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
51
- result += translated + "\n"
52
-
53
- return {"translated_text": result}
 
1
  from fastapi import FastAPI, File, UploadFile, Form
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
+ from langdetect import detect
4
+ import re
5
 
6
  app = FastAPI()
7
 
8
+ MODEL_NAME = "facebook/nllb-200-600M"
9
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
10
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
11
+
12
+ def split_text_lines(text, max_chunk_length=900):
13
+ # تقسيم ذكي مع الحفاظ على أسطر strings
14
+ lines = text.splitlines()
15
+ chunks = []
16
+ chunk = ""
17
+ for line in lines:
18
+ if len(chunk) + len(line) < max_chunk_length:
19
+ chunk += line + "\n"
20
+ else:
21
+ if chunk.strip():
22
+ chunks.append(chunk.strip())
23
+ chunk = line + "\n"
24
+ if chunk.strip():
25
+ chunks.append(chunk.strip())
26
+ return chunks
27
+
28
+ def batch_translate(texts, src_lang, tgt_lang):
29
+ # ترجمة سريعة batch
30
+ results = []
31
+ batch_size = 8 # يمكنك زيادة العدد حسب موارد السيرفر
32
+ for i in range(0, len(texts), batch_size):
33
+ batch = texts[i:i+batch_size]
34
+ tokenizer.src_lang = src_lang
35
+ inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=1024)
36
+ generated = model.generate(**inputs, forced_bos_token_id=tokenizer.convert_lang_to_id(tgt_lang))
37
+ translated = tokenizer.batch_decode(generated, skip_special_tokens=True)
38
+ results.extend(translated)
39
+ return results
40
+
41
+ def detect_language(text):
42
+ # كشف لغة ذكي (يعمل على أول chunk)
43
+ sample = text[:2000] if len(text) > 2000 else text
44
+ lang = detect(sample)
45
+ # وفق أكواد NLLB المتوافقة (تعديل سريع)
46
+ lang_map = {"en": "eng_Latn", "ar": "arb_Arab", "fr": "fra_Latn", "hi": "hin_Deva", "es": "spa_Latn", "de": "deu_Latn"}
47
+ return lang_map.get(lang, "eng_Latn")
48
 
49
  @app.post("/translate-text")
50
  async def translate_text(
51
  text: str = Form(...),
 
52
  target_lang: str = Form(...)
53
  ):
54
+ source_lang = detect_language(text)
55
+ texts = re.split(r'(?<=[.!?\n])\s+', text.strip())
56
+ chunks = []
57
+ cur_chunk = ""
58
+ for sentence in texts:
59
+ if len(cur_chunk) + len(sentence) < 900:
60
+ cur_chunk += sentence + " "
61
+ else:
62
+ chunks.append(cur_chunk.strip())
63
+ cur_chunk = sentence + " "
64
+ if cur_chunk.strip(): chunks.append(cur_chunk.strip())
65
+ translated = batch_translate(chunks, source_lang, target_lang)
66
+ return {"translated_text": "\n".join(translated)}
67
 
68
  @app.post("/translate-file")
69
  async def translate_file(
70
  file: UploadFile = File(...),
 
71
  target_lang: str = Form(...)
72
  ):
73
  contents = await file.read()
74
  original_text = contents.decode()
75
+ source_lang = detect_language(original_text)
76
+ lines = split_text_lines(original_text)
77
+ translated_lines = batch_translate(lines, source_lang, target_lang)
78
+ return {"translated_text": "\n".join(translated_lines)}