shbhro commited on
Commit
81f14ee
·
verified ·
1 Parent(s): 80aa299

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -0
app.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
3
+ import torch
4
+ import subprocess
5
+ import sys
6
+ import os
7
+
8
+ # --- Configuration ---
9
+ SYLHETI_TO_BN_MODEL = "shbhro/sylhetit5"
10
+ BN_TO_EN_MODEL = "csebuetnlp/banglat5_nmt_bn_en"
11
+ NORMALIZER_REPO = "https://github.com/csebuetnlp/normalizer.git"
12
+
13
+ # --- Helper function to install/import normalizer ---
14
+ # This ensures the normalizer is available.
15
+ # In HF Spaces, requirements.txt is the primary method.
16
+ normalizer_module = None
17
+ try:
18
+ from normalizer import normalize as normalize_fn_imported
19
+ normalizer_module = normalize_fn_imported
20
+ print("Normalizer imported successfully.")
21
+ except ImportError:
22
+ print(f"Normalizer library not found. Attempting to install from {NORMALIZER_REPO}...")
23
+ try:
24
+ # This command installs the package directly from git.
25
+ # The #egg=normalizer part helps pip identify the package name.
26
+ subprocess.check_call([sys.executable, "-m", "pip", "install", f"git+{NORMALIZER_REPO}#egg=normalizer"])
27
+ from normalizer import normalize as normalize_fn_imported_after_install
28
+ normalizer_module = normalize_fn_imported_after_install
29
+ print("Normalizer installed and imported successfully after pip install.")
30
+ except Exception as e:
31
+ print(f"Failed to install or import normalizer: {e}")
32
+ print("Please ensure 'git+https://github.com/csebuetnlp/normalizer.git#egg=normalizer' is in your requirements.txt for Hugging Face Spaces.")
33
+ # Fallback to a dummy function if installation fails, so the app can still load and show an error.
34
+ def dummy_normalize(text):
35
+ raise RuntimeError("Normalizer library could not be loaded. Please check installation.")
36
+ normalizer_module = dummy_normalize
37
+
38
+ # --- Model Loading (Globally, when the script starts) ---
39
+ sylheti_to_bn_pipe = None
40
+ bn_to_en_model = None
41
+ bn_to_en_tokenizer = None
42
+ model_device = None
43
+
44
+ print("Loading translation models...")
45
+ try:
46
+ model_device_type = "cuda" if torch.cuda.is_available() else "cpu"
47
+ model_device = torch.device(model_device_type)
48
+ hf_device_param = 0 if model_device_type == "cuda" else -1 # For pipeline
49
+
50
+ print(f"Using device: {model_device_type}")
51
+
52
+ sylheti_to_bn_pipe = pipeline(
53
+ "text2text-generation",
54
+ model=SYLHETI_TO_BN_MODEL,
55
+ device=hf_device_param
56
+ )
57
+ print(f"Sylheti-to-Bengali model ({SYLHETI_TO_BN_MODEL}) loaded.")
58
+
59
+ bn_to_en_model = AutoModelForSeq2SeqLM.from_pretrained(BN_TO_EN_MODEL)
60
+ bn_to_en_tokenizer = AutoTokenizer.from_pretrained(BN_TO_EN_MODEL, use_fast=False)
61
+ bn_to_en_model.to(model_device)
62
+ print(f"Bengali-to-English model ({BN_TO_EN_MODEL}) loaded.")
63
+
64
+ except Exception as e:
65
+ print(f"FATAL: Error loading one or more models: {e}")
66
+ # To prevent the app from crashing entirely if models don't load,
67
+ # but it will show errors during translation.
68
+ sylheti_to_bn_pipe = None
69
+ bn_to_en_model = None
70
+ bn_to_en_tokenizer = None
71
+
72
+ # --- Main Translation Logic ---
73
+ def translate_sylheti_to_english_gradio(sylheti_text_input):
74
+ if not sylheti_text_input.strip():
75
+ return "Please enter some Sylheti text.", ""
76
+
77
+ if not sylheti_to_bn_pipe:
78
+ return "Error: Sylheti-to-Bengali model not loaded. Check logs.", ""
79
+ if not bn_to_en_model or not bn_to_en_tokenizer:
80
+ return "Error: Bengali-to-English model not loaded. Check logs.", ""
81
+ if normalizer_module is None or isinstance(normalizer_module, type(lambda:0)) and normalizer_module.__name__ == 'dummy_normalize': # Check if it's the dummy
82
+ return "Error: Bengali normalizer library not available. Check logs.", ""
83
+
84
+
85
+ bengali_text_intermediate = "Error in Sylheti to Bengali step."
86
+ english_text_final = "Error in Bengali to English step."
87
+
88
+ # Step 1: Sylheti → Bengali
89
+ try:
90
+ print(f"Translating Sylheti to Bengali: '{sylheti_text_input}'")
91
+ bengali_translation_outputs = sylheti_to_bn_pipe(
92
+ sylheti_text_input,
93
+ max_length=128,
94
+ num_beams=5,
95
+ early_stopping=True
96
+ )
97
+ bengali_text_intermediate = bengali_translation_outputs[0]['generated_text']
98
+ print(f"Intermediate Bengali: '{bengali_text_intermediate}'")
99
+ except Exception as e:
100
+ print(f"Error during Sylheti to Bengali translation: {e}")
101
+ bengali_text_intermediate = f"Sylheti->Bengali Error: {str(e)}"
102
+ return bengali_text_intermediate, english_text_final # Stop if first step fails
103
+
104
+ # Step 2: Bengali → English
105
+ try:
106
+ print(f"Normalizing and translating Bengali to English: '{bengali_text_intermediate}'")
107
+ normalized_bn_text = normalizer_module(bengali_text_intermediate)
108
+ print(f"Normalized Bengali: '{normalized_bn_text}'")
109
+
110
+ input_ids = bn_to_en_tokenizer(
111
+ normalized_bn_text,
112
+ return_tensors="pt"
113
+ ).input_ids.to(model_device) # Ensure tensor is on the same device
114
+
115
+ generated_tokens = bn_to_en_model.generate(
116
+ input_ids,
117
+ max_length=128,
118
+ num_beams=5,
119
+ early_stopping=True
120
+ )
121
+ english_text_list = bn_to_en_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
122
+ english_text_final = english_text_list[0] if english_text_list else "No English output generated."
123
+ print(f"Final English: '{english_text_final}'")
124
+ except Exception as e:
125
+ print(f"Error during Bengali to English translation: {e}")
126
+ english_text_final = f"Bengali->English Error: {str(e)}"
127
+
128
+ return bengali_text_intermediate, english_text_final
129
+
130
+ # --- Gradio Interface Definition ---
131
+ iface = gr.Interface(
132
+ fn=translate_sylheti_to_english_gradio,
133
+ inputs=gr.Textbox(
134
+ lines=4,
135
+ label="Enter Sylheti Text",
136
+ placeholder="কিতা কিতা কিনলায় তে?"
137
+ ),
138
+ outputs=[
139
+ gr.Textbox(label="Intermediate Bengali Output", lines=4),
140
+ gr.Textbox(label="Final English Output", lines=4)
141
+ ],
142
+ title="🌍 Sylheti to English Translator (via Bengali)",
143
+ description=(
144
+ "Translates Sylheti text to English in two steps:\n"
145
+ f"1. Sylheti → Bengali (using `{SYLHETI_TO_BN_MODEL}`)\n"
146
+ f"2. Bengali → English (using `{BN_TO_EN_MODEL}` with text normalization from `{NORMALIZER_REPO.split('/')[-1]}`)"
147
+ ),
148
+ examples=[
149
+ ["কিতা কিতা কিনলায় তে?"],
150
+ ["তুমি কিতা কররায়?"],
151
+ ["আমি ভাত খাইছি।"],
152
+ ["আফনে ভালা আছনি?"]
153
+ ],
154
+ allow_flagging="never",
155
+ theme=gr.themes.Soft() # Optional: adds a bit of styling
156
+ )
157
+
158
+ # --- Launch the Gradio app ---
159
+ if __name__ == "__main__":
160
+ # When running locally, this launches the server.
161
+ # In Hugging Face Spaces, the `app.py` is typically run by their infrastructure.
162
+ iface.launch()