# NOTE: This file was exported from a Hugging Face Space (app.py).
# Viewer chrome from the export (Space status "Sleeping", file size 6,431 bytes,
# commit 60ffa1d, and the line-number gutter) has been folded into this comment
# so the module parses as valid Python.
import torch
import os
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit.processor import IndicProcessor
# Hugging Face auth token, read from the environment so it never lives in
# source control. None when unset; the gated model download will then fail.
token = os.getenv("HUGGINGFACE_HUB_TOKEN")
# Run on GPU when one is visible, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Fixed translation direction: English (Latin script) -> Kannada (Kannada script),
# using IndicTrans2's language-tag naming scheme.
src_lang, tgt_lang = "eng_Latn", "kan_Knda"
# Distilled 200M-parameter English->Indic checkpoint from AI4Bharat.
model_name = "ai4bharat/indictrans2-en-indic-dist-200M"
# Module-level singletons populated by load_model(); None until loading succeeds.
model = None
tokenizer = None
ip = None
def load_model():
    """Load the IndicTrans2 model, tokenizer, and processor into module globals.

    Populates the module-level ``model``, ``tokenizer``, and ``ip`` singletons
    that ``translate_text`` reads.

    Returns:
        bool: True when everything loaded successfully, False otherwise.
    """
    global model, tokenizer, ip
    try:
        print(f"Loading model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            token=token,
        )
        # Half precision halves GPU memory, but CPU fp16 kernels are
        # incomplete/slow — use full precision when no CUDA device exists.
        load_dtype = torch.float16 if DEVICE == "cuda" else torch.float32
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            dtype=load_dtype,
            token=token,
        ).to(DEVICE)
        model.eval()  # inference only: disable dropout etc.
        ip = IndicProcessor(inference=True)
        print(f"Model loaded successfully on {DEVICE}")
        return True
    except Exception as e:
        # Broad catch is deliberate: this runs at app startup, and any failure
        # (auth, network, missing weights) should degrade to the error UI in
        # create_interface() rather than crash the process.
        print(f"Error loading model: {str(e)}")
        return False
def translate_text(input_text):
    """Translate one English sentence to Kannada with the loaded model.

    Args:
        input_text: Raw English text from the UI textbox; surrounding
            whitespace is ignored.

    Returns:
        str: The Kannada translation on success, otherwise a user-facing
        error message (the UI displays whatever string comes back).
    """
    # Identity checks on purpose: `not model` would go through the model
    # object's truthiness instead of testing "loaded or not".
    if model is None or tokenizer is None or ip is None:
        # "❌" restores an emoji that a past ISO-8859-7 mis-decode had
        # garbled to "β".
        return "❌ Model not loaded. Please check the model configuration."
    text = input_text.strip()
    if not text:
        return "Please enter some text to translate."
    try:
        # IndicProcessor adds the language tags / normalization that
        # IndicTrans2 expects before tokenization. (The original's follow-up
        # emptiness check was dead code: [text] is always non-empty here.)
        batch = ip.preprocess_batch(
            [text],
            src_lang=src_lang,
            tgt_lang=tgt_lang,
        )
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)
        # Beam search decoding; generation settings kept from the original
        # configuration for this remote-code model.
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=False,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )
        decoded = tokenizer.batch_decode(
            generated_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        # Undo the tag/normalization step and restore native script.
        translations = ip.postprocess_batch(decoded, lang=tgt_lang)
        return translations[0] if translations else "Translation failed."
    except Exception as e:
        # Surface the failure in the output box instead of crashing the UI.
        return f"❌ Translation error: {str(e)}"
def create_interface():
    """Build and return the Gradio app.

    Loads the model first; if loading fails, returns a minimal error page
    instead of the translation UI.

    Returns:
        gr.Blocks: The interface to pass to ``launch()``.
    """
    # Loading at build time means the first user request is not stalled by
    # the model download.
    model_loaded = load_model()
    if not model_loaded:
        # Fallback page so the Space still renders something actionable.
        with gr.Blocks(title="Translation App - Error") as demo:
            gr.Markdown("## ❌ Model Loading Error")
            gr.Markdown("Failed to load the translation model. Please check:")
            gr.Markdown("- Your Hugging Face token is set correctly")
            gr.Markdown("- You have access to the gated model")
            gr.Markdown("- Your internet connection is working")
        return demo

    # NOTE(review): emoji below replace characters garbled by a past
    # ISO-8859-7 mis-decode ("β" was verifiably "❌", "ποΈ" was "🗑️");
    # the bare "π" glyphs were emoji whose exact identity is unrecoverable,
    # so plausible ones were chosen.
    with gr.Blocks(
        title="AI4Bharat IndicTrans2 Translation",
        theme=gr.themes.Soft(),
    ) as demo:
        gr.Markdown(
            f"""
# 🌐 AI4Bharat IndicTrans2 Translation
**Current Configuration:**
- **Source Language:** {src_lang} (English)
- **Target Language:** {tgt_lang} (Kannada)
- **Model:** {model_name}
- **Device:** {DEVICE}
Enter text below to translate from English to Kannada.
"""
        )
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label=f"Input Text ({src_lang})",
                    placeholder="Enter English text to translate...",
                    lines=5,
                    max_lines=10,
                )
                with gr.Row():
                    translate_btn = gr.Button("🔄 Translate", variant="primary")
                    clear_btn = gr.Button("🗑️ Clear")
            with gr.Column():
                output_text = gr.Textbox(
                    label=f"Translation ({tgt_lang})",
                    lines=5,
                    max_lines=10,
                    interactive=False,
                )

        gr.Markdown("### 📝 Example Inputs:")
        examples = [
            ["Hello, how are you?"],
            ["I am going to the market today."],
            ["This is a very beautiful place."],
            ["Can you help me?"],
        ]
        # cache_examples=True runs translate_text on each example at startup;
        # slower launch, instant example clicks afterwards.
        gr.Examples(
            examples=examples,
            inputs=[input_text],
            outputs=[output_text],
            fn=translate_text,
            cache_examples=True,
        )

        # Wire the buttons to the translation function.
        translate_btn.click(
            fn=translate_text,
            inputs=[input_text],
            outputs=[output_text],
        )
        clear_btn.click(
            fn=lambda: ("", ""),
            outputs=[input_text, output_text],
        )

        gr.Markdown("---")
    return demo
if __name__ == "__main__":
    # Build the UI (this also loads the model) and start the web server.
    # A stray trailing "|" from the file-viewer export was removed here.
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",  # listen on all interfaces (container/Space friendly)
        server_port=7860,       # Gradio's default port
        share=False,            # set True for a temporary public gradio.live link
        debug=True,
        show_error=True,
    )