|
|
import gradio as gr
|
|
|
import torch
|
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
|
import os
|
|
|
|
|
|
# Checkpoint handed to `from_pretrained` for both the tokenizer and the
# English→Hindi seq2seq model (a Hugging Face Hub repo id or a local path).
checkpoint_path = "Prince012/anuvaadak-en_hi"

# Target device: the default CUDA device when PyTorch can see one, else CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Filled in by load_model(); both start out as None.
tokenizer = model = None
|
|
|
|
|
|
def load_model():
    """Load the tokenizer and seq2seq model from ``checkpoint_path`` onto ``device``.

    On success the module-level ``tokenizer`` and ``model`` globals are
    replaced together and the model is switched to evaluation mode. On
    failure the error is printed and the globals are left exactly as they
    were, so callers never observe a half-initialised pair.

    Returns:
        True if loading succeeded, False otherwise. (Existing callers that
        ignore the return value are unaffected.)
    """
    global tokenizer, model

    try:
        # Build into locals first: if the tokenizer loads but the weights (or
        # the device move) fail, nothing partial is published to the handlers.
        new_tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
        # Pick the dtype from the device we actually move to, so half
        # precision is only ever used on CUDA; CPU inference stays float32.
        dtype = torch.float16 if device == "cuda" else torch.float32
        new_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path, torch_dtype=dtype)
        new_model.to(device)
        new_model.eval()
    except Exception as e:
        # Best-effort at startup: report and keep the app importable.
        print(f"Error loading model: {e}")
        return False

    tokenizer, model = new_tokenizer, new_model
    print(f"Model loaded successfully on {device}!")
    return True
|
|
|
|
|
|
def generate_translation(text, max_input_length=128, max_output_length=128):
    """Translate one English string to Hindi with the loaded seq2seq model.

    Args:
        text: English source text. Non-strings and blank/whitespace-only
            strings produce "".
        max_input_length: Maximum number of source tokens; longer inputs are
            truncated by the tokenizer.
        max_output_length: Passed to ``model.generate`` as ``max_length`` to
            cap the generated sequence.

    Returns:
        The decoded Hindi text, "" for blank input, or a message beginning
        with "Translation error:" if the model is not loaded or inference
        raises.
    """
    if not isinstance(text, str) or not text.strip():
        return ""

    # load_model() only prints on failure, so check here instead of letting
    # the tokenizer call fail with "'NoneType' object is not callable".
    if tokenizer is None or model is None:
        return "Translation error: model is not loaded"

    try:
        inputs = tokenizer(
            text,
            return_tensors="pt",
            max_length=max_input_length,
            truncation=True,
        )
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

        with torch.no_grad():
            outputs = model.generate(**inputs, max_length=max_output_length)

        # Only one sequence is generated per call; decode it on the CPU.
        return tokenizer.decode(outputs[0].cpu(), skip_special_tokens=True)
    except Exception as e:
        # UI boundary: surface the failure as text rather than crashing the handler.
        print(f"Error during translation: {e}")
        return f"Translation error: {e}"
|
|
|
|
|
|
def translate_text(english_text):
    """Return the Hindi translation of *english_text*.

    Falsy input (``None`` or "") short-circuits to "" without touching the
    model; anything else is delegated to generate_translation().
    """
    return generate_translation(english_text) if english_text else ""
|
|
|
|
|
|
|
|
|
def count_chars(text):
    """Return a human-readable character count such as "3 characters".

    ``None`` is treated as empty text, and the unit is singular for exactly
    one character ("1 character").
    """
    n = len(text) if text is not None else 0
    unit = "character" if n == 1 else "characters"
    return f"{n} {unit}"
|
|
|
|
|
|
def clear_inputs():
    """Return a fresh two-element list of empty strings.

    NOTE(review): not referenced anywhere in the visible code; the Clear
    button is wired to clear_all() instead.
    """
    blank = ""
    return [blank, blank]
|
|
|
|
|
|
|
|
|
# On a CUDA host, release memory held by PyTorch's caching allocator before
# the weights are loaded; CPU-only hosts skip this step.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Attempt to populate the module-level `tokenizer` and `model` once, at import
# time, so every UI request reuses the same objects.
load_model()
|
|
|
|
|
|
|
|
|
# Page stylesheet passed to gr.Blocks(css=...).
# The primary-button rule lists two selectors: `.gr-button-primary` is the
# class older Gradio releases put on primary buttons, while `button.primary`
# matches the variant class used by newer releases (confirm against the
# installed Gradio version).
css = """
body {
    font-family: Arial, sans-serif;
}
.container {
    max-width: 800px;
    margin: 0 auto;
}
h1 {
    color: #4169E1;
    text-align: center;
    font-size: 32px;
    margin-bottom: 5px;
}
h2 {
    color: #666;
    text-align: center;
    font-size: 18px;
    font-weight: normal;
    margin-top: 0;
    margin-bottom: 30px;
}
.textbox-label {
    font-weight: bold;
    margin-bottom: 5px;
}
.gradio-container {
    background-color: white;
}
.gr-button-primary,
button.primary {
    background-color: #4169E1 !important;
}
footer {
    text-align: center;
    margin-top: 20px;
    color: #888;
    font-size: 12px;
}
"""
|
|
|
|
|
|
def _char_count_html(text):
    """Render count_chars(text) as the right-aligned grey counter markup.

    The counters are gr.HTML components, so every update must carry its own
    styling; the handlers used to return bare text, which dropped the styling
    after the first update. ``None`` is treated as "".
    """
    return (
        '<div style="text-align: right; color: #888; font-size: 12px;">'
        f"{count_chars(text or '')}</div>"
    )


def translate_and_count(text):
    """Translate *text* and return (translation, styled character count)."""
    translation = translate_text(text)
    return translation, _char_count_html(translation)


def clear_all():
    """Reset both textboxes and both character counters."""
    empty_count = _char_count_html("")
    return ["", "", empty_count, empty_count]


with gr.Blocks(css=css) as demo:
    # Title and subtitle, styled by the h1/h2 rules in `css`.
    gr.HTML("""
    <div class="container">
        <h1>Anuvaadak</h1>
        <h2>English to Hindi Translator</h2>
    </div>
    """)

    with gr.Row():
        # Left column: editable English source.
        with gr.Column():
            gr.HTML('<div class="textbox-label">🇬🇧 English</div>')
            english_input = gr.Textbox(
                placeholder="Enter English text here...",
                lines=6,
                max_lines=10,
                show_label=False,
            )
            english_char_count = gr.HTML(_char_count_html(""))

        # Right column: read-only Hindi output.
        with gr.Column():
            gr.HTML('<div class="textbox-label">🇮🇳 Hindi</div>')
            hindi_output = gr.Textbox(
                placeholder="Translation will appear here...",
                lines=6,
                max_lines=10,
                show_label=False,
                interactive=False,
            )
            hindi_char_count = gr.HTML(_char_count_html(""))

    with gr.Row():
        translate_btn = gr.Button("Translate", variant="primary")
        clear_btn = gr.Button("Clear")

    gr.HTML("""
    <footer>
        © 2025 DeepTrans. All rights reserved.
    </footer>
    """)

    # Live character count for the English input.
    english_input.change(
        _char_count_html,
        inputs=[english_input],
        outputs=[english_char_count],
    )

    translate_btn.click(
        translate_and_count,
        inputs=[english_input],
        outputs=[hindi_output, hindi_char_count],
    )

    clear_btn.click(
        clear_all,
        inputs=None,
        outputs=[english_input, hindi_output, english_char_count, hindi_char_count],
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Serve on every network interface (not only localhost), port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)