RohanAi committed on
Commit
b946487
·
verified ·
1 Parent(s): a2a0ad4

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -69
app.py DELETED
@@ -1,69 +0,0 @@
1
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, BitsAndBytesConfig
2
- from sacremoses import MosesPunctNormalizer
3
- from flores import code_mapping
4
- import gradio as gr
5
- import platform
6
-
7
# Inference is pinned to the CPU. The previous platform check was always
# overridden by a hard-coded "cpu" assignment, so the device is set directly.
# NOTE: a GPU path (8-bit loading through BitsAndBytesConfig with
# device_map="auto") was sketched here earlier but never enabled; its branch
# left `model` undefined, so it has been dropped rather than kept half-wired.
device = "cpu"
MODEL_DIR = "./nllb-600M-quantized"

# Load the tokenizer and the seq2seq model from the local model directory.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_DIR)

# Punctuation normaliser applied to the input text before tokenisation.
# NOTE(review): it is configured for English regardless of the selected
# source language -- confirm that this is intended.
punct_normalizer = MosesPunctNormalizer(lang="en")
34
-
35
def translate(text: str, src_lang: str, tgt_lang: str):
    """Translate ``text`` from ``src_lang`` into ``tgt_lang``.

    Args:
        text: The text to translate.
        src_lang: Source language name; must be a key of ``code_mapping``
            (e.g. "English" -> "eng_Latn").
        tgt_lang: Target language name; must be a key of ``code_mapping``
            (e.g. "Hindi" -> "hin_Deva").

    Returns:
        The decoded translation with special tokens stripped, or an empty
        string when ``text`` is empty or contains only whitespace.

    Raises:
        KeyError: If ``src_lang`` or ``tgt_lang`` is not in ``code_mapping``
            (this includes ``None`` from an unselected dropdown).
    """
    # Nothing to translate: skip the model call instead of decoding noise
    # from an empty prompt.
    if not text or not text.strip():
        return ""

    # Resolve display names to the language codes the tokenizer expects.
    try:
        src_code = code_mapping[src_lang]
        tgt_code = code_mapping[tgt_lang]
    except KeyError as err:
        raise KeyError(f"Unsupported language: {err.args[0]!r}") from None
    print('source lang code ', src_code)

    tokenizer.src_lang = src_code
    tokenizer.tgt_lang = tgt_code

    # Normalize punctuation before tokenising.
    text = punct_normalizer.normalize(text)

    # Encode and generate. No explicit max_length is passed, so whatever
    # default the model's generation settings provide is used.
    inputs = tokenizer(text, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        # Force the first generated token to be the target language code.
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
        num_beams=3,  # beam search with three beams (not greedy decoding)
        no_repeat_ngram_size=2,  # small repetition control
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
56
-
57
-
58
# Every language name known to the code mapping is offered in both dropdowns.
langs = list(code_mapping)

# Input widgets, in the positional order expected by translate().
input_components = [
    gr.Textbox(lines=10, label="Input Text"),
    gr.Dropdown(langs, label="Source Language"),
    gr.Dropdown(langs, label="Target Language"),
]
output_component = gr.Textbox(lines=30, label="Translated Text")

iface = gr.Interface(
    fn=translate,
    inputs=input_components,
    outputs=output_component,
    title="🌍 Language Translation (CPU-friendly)",
)

# Start the web UI; share=True additionally requests a public share link.
iface.launch(share=True)