Spaces:
Sleeping
Sleeping
| # app.py | |
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
# Dropdown label -> Hugging Face Hub model id.
# Only the official Google FLAN-T5 checkpoints are offered.
MODEL_OPTIONS = {
    "FLAN-T5-small (Google)": "google/flan-t5-small",
    "FLAN-T5-base (Google)": "google/flan-t5-base",
}

# Lazily filled cache: model id -> ready-to-use pipeline (see get_pipeline).
loaded_pipelines = dict()
def get_pipeline(model_id: str):
    """Return the cached text2text-generation pipeline for *model_id*.

    On the first request for a given id this loads the tokenizer and the
    seq2seq model, wraps them in a CPU pipeline, runs one short greedy
    generation as a warm-up, and stores the pipeline in the module-level
    ``loaded_pipelines`` dict. Subsequent calls return the stored object.
    """
    # Fast path: already built for this id.
    if model_id in loaded_pipelines:
        return loaded_pipelines[model_id]

    tok = AutoTokenizer.from_pretrained(model_id)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,  # CPU optimization: lower peak RAM while loading
        torch_dtype="auto",  # let transformers pick the dtype for the checkpoint
    )
    generator = pipeline(
        "text2text-generation",
        model=seq2seq,
        tokenizer=tok,
        device=-1,  # -1 selects CPU
    )

    # One tiny greedy generation so the first real request avoids the
    # one-time startup lag.
    generator("Correct the grammar: test", max_new_tokens=8, do_sample=False)

    loaded_pipelines[model_id] = generator
    return generator
def oxford_polish(
    sentence: str,
    model_choice: str,
    *,
    max_new_tokens: int = 60,
    num_beams: int = 2,
) -> str:
    """Return a grammar-corrected version of *sentence*.

    Args:
        sentence: Text entered by the user. Surrounding whitespace is
            ignored. ``None`` or blank input returns ``""`` immediately,
            without loading a model.
        model_choice: Display label; must be a key of ``MODEL_OPTIONS``.
        max_new_tokens: Upper bound on generated tokens (default 60, the
            value previously hard-coded).
        num_beams: Beam-search width (default 2). Decoding is deterministic
            because sampling is disabled.

    Returns:
        The model output with surrounding whitespace removed. If the model
        echoed the prompt, the echo is also removed.

    Raises:
        KeyError: if *sentence* is non-blank and *model_choice* is not a key
            of ``MODEL_OPTIONS``.
    """
    cleaned = (sentence or "").strip()
    if not cleaned:
        # Nothing to correct: don't load a model or generate filler text.
        return ""

    model_id = MODEL_OPTIONS[model_choice]
    polisher = get_pipeline(model_id)

    # Instruction-style prompt for FLAN-T5. The stripped sentence is used so
    # stray newlines from the multi-line textbox don't leak into the prompt.
    prompt = (
        "You are an English grammar corrector and teacher. "
        f"Return the corrected version: {cleaned}"
    )
    out = polisher(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_beams=num_beams,
    )
    text = out[0]["generated_text"].strip()

    # Best-effort: strip the prompt if the model repeated it verbatim.
    if text.startswith(prompt):
        text = text[len(prompt):].strip()
    return text
# Gradio interface: a sentence textbox plus a model dropdown, wired to
# oxford_polish, with a single textbox for the corrected output.
demo = gr.Interface(
    fn=oxford_polish,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter a sentence to correct..."),
        gr.Dropdown(
            choices=list(MODEL_OPTIONS),
            value="FLAN-T5-base (Google)",
            label="Choose Model",
        ),
    ],
    outputs=gr.Textbox(label="Oxford-grammar Correction"),
    title="Oxford Grammar Polisher",
    description="Compare Google’s official FLAN-T5 small and base models for grammar correction.",
)

# Start the server only when this file is executed directly (e.g.
# `python app.py`). Importing the module, for tests or reuse of `demo`, no
# longer starts and blocks on a web server.
# NOTE(review): assumes the hosting environment runs this file as a script;
# confirm against the deployment setup.
if __name__ == "__main__":
    demo.launch()