# BhashaBridge / app.py
# Hugging Face Space by Ullas26 (commit ced4885)
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces
import torch
# Hugging Face Hub id of Sarvam AI's Indic translation model.
MODEL_NAME = "sarvamai/sarvam-translate"

# Load tokenizer (downloads on first run, then served from the local cache).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Auto device: prefer CUDA when a GPU is present, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model. fp16 halves memory on GPU; CPU kernels generally need fp32.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
model.to(device)

# Target-language choices shown in the UI — presumably the languages the
# model was trained on; TODO confirm against the sarvam-translate model card.
LANG_OPTIONS = [
    "Hindi", "Bengali", "Marathi", "Telugu", "Tamil", "Gujarati", "Urdu", "Kannada", "Odia", "Malayalam", "Punjabi", "Assamese", "Maithili", "Santali", "Kashmiri", "Nepali", "Sindhi", "Dogri", "Konkani", "Manipuri (Meitei)", "Bodo", "Sanskrit"
]
@spaces.GPU  # enables GPU when available
def generate(tgt_lang, input_txt):
    """Translate *input_txt* into *tgt_lang* with the sarvam-translate model.

    Args:
        tgt_lang: Target language name (one of LANG_OPTIONS).
        input_txt: Source text to translate.

    Returns:
        The translated text, or a prompt message when the input is empty.
    """
    # Guard against None as well as whitespace-only input — Gradio can pass
    # None when the textbox is cleared, and .strip() on None would raise.
    if not input_txt or not input_txt.strip():
        return "Please enter text to translate."

    # The model is instruction-tuned: the target language goes in the
    # system turn, the text to translate in the user turn.
    messages = [
        {"role": "system", "content": f"Translate the following sentence into {tgt_lang}."},
        {"role": "user", "content": input_txt},
    ]

    # Render the chat turns into a single prompt string with the
    # generation prefix appended.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer([text], return_tensors="pt").to(device)

    # inference_mode: no autograd bookkeeping — faster and lighter on memory.
    with torch.inference_mode():
        generated = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.2,
        )

    # Drop the prompt tokens; decode only what the model newly generated.
    output_ids = generated[0][len(inputs.input_ids[0]):]
    return tokenizer.decode(output_ids, skip_special_tokens=True).strip()
# Build the UI. NOTE: the input Textbox uses `placeholder` (greyed hint text)
# rather than `value` — a default *value* pre-fills the box and would itself
# be submitted for translation if the user forgot to clear it.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Radio(LANG_OPTIONS, label="Target Language", value="Hindi"),
        gr.Textbox(label="Input Text", placeholder="Enter the word/text to be translated"),
    ],
    outputs=gr.Textbox(label="Translation"),
    title="BhashaBridge",
)

if __name__ == "__main__":
    demo.launch()