# app.py — grammar correction + Spanish/Arabic translation (Gradio app)
# Last update: commit 8aea8f3
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, M2M100ForConditionalGeneration, M2M100Tokenizer
# import gradio as gr
# # Grammar correction model
# tokenizer_gc = AutoTokenizer.from_pretrained("vennify/t5-base-grammar-correction")
# model_gc = AutoModelForSeq2SeqLM.from_pretrained("vennify/t5-base-grammar-correction")
# # Translation model
# model_name = "facebook/m2m100_418M"
# tokenizer_mt = M2M100Tokenizer.from_pretrained(model_name)
# model_mt = M2M100ForConditionalGeneration.from_pretrained(model_name)
# # Function to correct grammar
# def correct_grammar(sentence):
# input_text = f"grammar: {sentence}"
# tokens = tokenizer_gc.encode(input_text, return_tensors="pt")
# outputs = model_gc.generate(tokens, max_length=64, num_beams=4, early_stopping=True)
# corrected = tokenizer_gc.decode(outputs[0], skip_special_tokens=True)
# return corrected
# # Function to translate text
# def translate_text(text, lang_code):
# tokenizer_mt.src_lang = "en"
# encoded = tokenizer_mt(text, return_tensors="pt")
# generated = model_mt.generate(**encoded, forced_bos_token_id=tokenizer_mt.lang_code_to_id[lang_code])
# return tokenizer_mt.decode(generated[0], skip_special_tokens=True)
# # Combined pipeline
# def pipeline(sentence):
# corrected = correct_grammar(sentence)
# grammar_status = "Correct" if corrected.lower() == sentence.lower() else "Corrected"
# spanish = translate_text(corrected, "es")
# arabic = translate_text(corrected, "ar")
# return {
# "Input Sentence": sentence,
# "Grammar Status": grammar_status,
# "Corrected Sentence": corrected,
# "Spanish Translation": spanish,
# "Arabic Translation": arabic
# }
# # Gradio UI
# iface = gr.Interface(
# fn=pipeline,
# inputs="text",
# outputs="json",
# title="Grammar Check + Spanish & Arabic Translator",
# description="This tool checks grammar, corrects it if needed, and translates to Spanish and Arabic."
# )
# if __name__ == "__main__":
# iface.launch()
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import MarianMTModel, MarianTokenizer
import gradio as gr
import torch
# Select the compute device: GPU when CUDA is available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Grammar-correction model (T5 fine-tuned for grammar fixing).
tokenizer_gc = AutoTokenizer.from_pretrained("vennify/t5-base-grammar-correction")
model_gc = AutoModelForSeq2SeqLM.from_pretrained("vennify/t5-base-grammar-correction").to(device)

# Translation models (MarianMT, one model per target language).
# English -> Spanish
sp_model_name = "Helsinki-NLP/opus-mt-en-es"
tokenizer_es = MarianTokenizer.from_pretrained(sp_model_name)
model_es = MarianMTModel.from_pretrained(sp_model_name).to(device)

# English -> Arabic
ar_model_name = "Helsinki-NLP/opus-mt-en-ar"
tokenizer_ar = MarianTokenizer.from_pretrained(ar_model_name)
model_ar = MarianMTModel.from_pretrained(ar_model_name).to(device)
# Grammar correction function
def correct_grammar(sentence):
    """Correct the grammar of an English sentence.

    Prefixes the sentence with the ``grammar:`` task tag expected by the
    T5 grammar-correction model, generates a corrected version with beam
    search, and returns the decoded text.

    Args:
        sentence: English sentence to check and correct.

    Returns:
        The model's corrected sentence (may be identical to the input).
    """
    input_text = f"grammar: {sentence}"
    # Call the tokenizer directly (the supported API) and truncate overly
    # long inputs so a huge paste cannot exceed the model's position limit
    # or memory budget; behavior is unchanged for normal-length sentences.
    tokens = tokenizer_gc(
        input_text, return_tensors="pt", truncation=True, max_length=256
    ).input_ids.to(device)
    # no_grad: inference only — skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model_gc.generate(tokens, max_length=64, num_beams=4, early_stopping=True)
    corrected = tokenizer_gc.decode(outputs[0], skip_special_tokens=True)
    return corrected
# English to Spanish translation
def translate_to_spanish(text):
    """Translate English text to Spanish using the MarianMT en-es model.

    Args:
        text: English source text.

    Returns:
        The Spanish translation as a plain string.
    """
    # prepare_seq2seq_batch() is deprecated and removed in recent
    # transformers releases; calling the tokenizer directly is the
    # supported equivalent. Truncate so oversized input cannot crash
    # generation.
    batch = tokenizer_es([text], return_tensors="pt", truncation=True).to(device)
    # Inference only — no gradients needed.
    with torch.no_grad():
        generated = model_es.generate(**batch)
    return tokenizer_es.decode(generated[0], skip_special_tokens=True)
# English to Arabic translation
def translate_to_arabic(text):
    """Translate English text to Arabic using the MarianMT en-ar model.

    Args:
        text: English source text.

    Returns:
        The Arabic translation as a plain string.
    """
    # prepare_seq2seq_batch() is deprecated and removed in recent
    # transformers releases; calling the tokenizer directly is the
    # supported equivalent. Truncate so oversized input cannot crash
    # generation.
    batch = tokenizer_ar([text], return_tensors="pt", truncation=True).to(device)
    # Inference only — no gradients needed.
    with torch.no_grad():
        generated = model_ar.generate(**batch)
    return tokenizer_ar.decode(generated[0], skip_special_tokens=True)
# Combined pipeline
def pipeline(sentence):
    """Grammar-check a sentence, then translate the corrected text.

    Runs grammar correction first, decides whether the input needed any
    change (case/whitespace-insensitive comparison), and translates the
    corrected sentence to Spanish and Arabic.

    Args:
        sentence: English input sentence from the UI.

    Returns:
        Dict with the original sentence, grammar status, corrected
        sentence, and both translations.
    """
    fixed = correct_grammar(sentence)
    unchanged = fixed.strip().lower() == sentence.strip().lower()
    # Dict values are evaluated in order, so Spanish is translated
    # before Arabic — same call order as before.
    return {
        "Input Sentence": sentence,
        "Grammar Status": "Correct" if unchanged else "Corrected",
        "Corrected Sentence": fixed,
        "Spanish Translation": translate_to_spanish(fixed),
        "Arabic Translation": translate_to_arabic(fixed),
    }
# Gradio Interface: single text box in, JSON report out.
iface = gr.Interface(
    fn=pipeline,
    inputs=gr.Textbox(label="Enter a sentence in English"),
    outputs="json",
    title="πŸ“ Grammar Checker + 🌐 Translator (Fast Version)",
    description="Checks grammar, corrects if needed, and translates the corrected sentence to Spanish and Arabic using optimized models."
)
if __name__ == "__main__":
    # Optional warm-up: run the full pipeline once so the first real user
    # request does not pay the first-call latency.
    print("Warming up models...")
    _ = pipeline("This is a test sentence.")
    # debug=True surfaces tracebacks in the Gradio UI/logs.
    iface.launch(debug=True)