AItool commited on
Commit
fb10678
·
verified ·
1 Parent(s): 1f223c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -48
app.py CHANGED
@@ -1,52 +1,47 @@
1
  import gradio as gr
2
- from diffusers import StableDiffusionImg2ImgPipeline
3
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
- import torch
5
-
6
- # --- Device selection: CUDA if available, else CPU ---
7
- device = "cuda" if torch.cuda.is_available() else "cpu"
8
-
9
- # --- Load Stable Diffusion for style transfer ---
10
- sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
11
- "runwayml/stable-diffusion-v1-5"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  )
13
- sd_pipe = sd_pipe.to(device)
14
- sd_pipe.enable_attention_slicing() # helps reduce memory usage
15
-
16
- # --- Load T5-base or T5-small for grammar correction ---
17
- model_name = "vennify/t5-base-grammar-correction" # or "prithivida/grammar_error_correcter_v1"
18
- tokenizer = AutoTokenizer.from_pretrained(model_name)
19
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
20
-
21
- def style_transfer(input_image, prompt, strength=0.5, guidance=7.5):
22
- result = sd_pipe(
23
- prompt=prompt,
24
- image=input_image,
25
- strength=strength,
26
- guidance_scale=guidance
27
- ).images[0]
28
- return result
29
-
30
- def correct_text(text):
31
- inputs = tokenizer("gec: " + text, return_tensors="pt").to(device)
32
- outputs = model.generate(**inputs, max_length=128)
33
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
34
-
35
- # --- Gradio UI ---
36
- with gr.Blocks() as demo:
37
- gr.Markdown("## 🎨 Style Transfer + ✍️ Grammar Correction")
38
-
39
- with gr.Tab("Image Style Transfer"):
40
- img_in = gr.Image(type="pil")
41
- prompt = gr.Textbox(label="Style Prompt")
42
- img_out = gr.Image()
43
- btn1 = gr.Button("Transfer Style")
44
- btn1.click(style_transfer, inputs=[img_in, prompt], outputs=img_out)
45
-
46
- with gr.Tab("Text Correction"):
47
- txt_in = gr.Textbox(label="Enter text")
48
- txt_out = gr.Textbox(label="Corrected text")
49
- btn2 = gr.Button("Correct")
50
- btn2.click(correct_text, inputs=txt_in, outputs=txt_out)
51
 
52
  demo.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
+
4
+ # Available models
5
+ MODEL_OPTIONS = {
6
+ "Prithivida GEC v1": "prithivida/grammar_error_correcter_v1",
7
+ "Hassaanik GEC": "hassaanik/grammar-correction-model",
8
+ "Vennify T5 GEC": "vennify/t5-base-grammar-correction"
9
+ }
10
+
11
+ # Cache loaded pipelines so we don’t reload every time
12
+ loaded_pipelines = {}
13
+
14
+ def get_pipeline(model_id):
15
+ if model_id not in loaded_pipelines:
16
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
17
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
18
+ loaded_pipelines[model_id] = pipeline("text2text-generation",
19
+ model=model,
20
+ tokenizer=tokenizer)
21
+ return loaded_pipelines[model_id]
22
+
23
+ def oxford_polish(sentence: str, model_choice: str) -> str:
24
+ model_id = MODEL_OPTIONS[model_choice]
25
+ polisher = get_pipeline(model_id)
26
+ prompt = (
27
+ "Correct this sentence into formal written English, following the Oxford University Style Guide. "
28
+ "Ensure tense matches time expressions (e.g. 'tomorrow' → future, 'yesterday' → past), "
29
+ "use British spelling, apply the Oxford comma, and correct uncountable nouns naturally. "
30
+ "Sentence: " + sentence
31
+ )
32
+ out = polisher(prompt, max_new_tokens=80, do_sample=False)
33
+ return out[0]["generated_text"].strip()
34
+
35
+ # Gradio interface
36
+ demo = gr.Interface(
37
+ fn=oxford_polish,
38
+ inputs=[
39
+ gr.Textbox(lines=2, placeholder="Enter a sentence to correct..."),
40
+ gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), value="Prithivida GEC v1", label="Choose Model")
41
+ ],
42
+ outputs=gr.Textbox(label="Oxford-style Correction"),
43
+ title="Oxford Grammar Polisher",
44
+ description="Test multiple free grammar correction models from Hugging Face Hub with Oxford-style rules."
45
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  demo.launch()