AItool committed
Commit f52a8f9 · verified · 1 Parent(s): d2b6695

Update app.py

Files changed (1)
app.py +263 -44
app.py CHANGED
@@ -1,47 +1,266 @@
  import gradio as gr
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
-
- # Available models
- MODEL_OPTIONS = {
-     "Prithivida GEC v1": "prithivida/grammar_error_correcter_v1",
-     "Hassaanik GEC": "hassaanik/grammar-correction-model",
-     "Vennify T5 GEC": "vennify/t5-base-grammar-correction"
- }
-
- # Cache loaded pipelines so we don’t reload every time
- loaded_pipelines = {}
-
- def get_pipeline(model_id):
-     if model_id not in loaded_pipelines:
-         tokenizer = AutoTokenizer.from_pretrained(model_id)
-         model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
-         loaded_pipelines[model_id] = pipeline("text2text-generation",
-                                               model=model,
-                                               tokenizer=tokenizer)
-     return loaded_pipelines[model_id]
-
- def oxford_polish(sentence: str, model_choice: str) -> str:
-     model_id = MODEL_OPTIONS[model_choice]
-     polisher = get_pipeline(model_id)
      prompt = (
-         "Correct this sentence into formal written English, following the Oxford University Style Guide. "
-         "Ensure tense matches time expressions (e.g. 'tomorrow' → future, 'yesterday' → past), "
-         "use British spelling, apply the Oxford comma, and correct uncountable nouns naturally. "
-         "Sentence: " + sentence
      )
-     out = polisher(prompt, max_new_tokens=80, do_sample=False)
-     return out[0]["generated_text"].strip()
-
- # Gradio interface
- demo = gr.Interface(
-     fn=oxford_polish,
-     inputs=[
-         gr.Textbox(lines=2, placeholder="Enter a sentence to correct..."),
-         gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), value="Prithivida GEC v1", label="Choose Model")
-     ],
-     outputs=gr.Textbox(label="Oxford-style Correction"),
-     title="Oxford Grammar Polisher",
-     description="Test multiple free grammar correction models from Hugging Face Hub with Oxford-grammar rules."
- )
-
- demo.launch()
+ import os
+ import gc
+ import time
  import gradio as gr
+ import torch
+ from PIL import Image
+
+ # -----------------------
+ # Device + CPU perf knobs
+ # -----------------------
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Threads (tune for HF CPU Space)
+ os.environ.setdefault("OMP_NUM_THREADS", "4")
+ os.environ.setdefault("MKL_NUM_THREADS", "4")
+ torch.set_num_threads(int(os.environ["OMP_NUM_THREADS"]))
+ torch.set_num_interop_threads(max(1, int(os.environ["OMP_NUM_THREADS"]) // 2))
+
+ INFER = torch.inference_mode if hasattr(torch, "inference_mode") else torch.no_grad
+
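+ # Note: torch.inference_mode (PyTorch 1.9+) is a stricter, slightly faster
+ # variant of no_grad; the hasattr check keeps older PyTorch builds working.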
+ # -----------------------
+ # Stable Diffusion 1.5 (img2img) for style transfer
+ # -----------------------
+ from diffusers import StableDiffusionImg2ImgPipeline, EulerAncestralDiscreteScheduler
+
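+ # One-time pipeline build. Euler Ancestral holds up well at low step counts,
+ # and attention/VAE slicing plus tiling cut peak memory at a small speed cost,
+ # which helps on a small CPU Space.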
+ def load_sd15_pipe():
+     pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+         "runwayml/stable-diffusion-v1-5",
+         safety_checker=None,
+         requires_safety_checker=False,
+     )
+     pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
+     pipe = pipe.to(device)
+     pipe.enable_attention_slicing()
+     pipe.enable_vae_tiling()
+     pipe.enable_vae_slicing()
+     if device == "cuda":
+         pipe.unet.to(memory_format=torch.channels_last)
+     return pipe
+
+ _sd_pipe = None
+
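+ # img2img: the upload is encoded to latents, noised in proportion to `strength`,
+ # then denoised toward the prompt. Strength near 1.0 mostly ignores the input;
+ # near 0.0 it barely changes it.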
+ def sd_style_transfer(input_image, prompt, strength=0.55, guidance=5.5, steps=18, width=512, height=512, seed=0):
+     global _sd_pipe
+     if input_image is None:
+         raise gr.Error("Please upload an input image.")
+     if not prompt or not prompt.strip():
+         raise gr.Error("Please provide a style prompt.")
+
+     if _sd_pipe is None:
+         t0 = time.time()
+         _sd_pipe = load_sd15_pipe()
+         print(f"[SD] Pipeline loaded in {time.time()-t0:.2f}s on {device}.", flush=True)
+
+     generator = torch.Generator(device=device) if device == "cuda" else torch.Generator()
+     if isinstance(seed, (int, float)) and int(seed) > 0:
+         generator = generator.manual_seed(int(seed))
+
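+     # With seed <= 0 the generator stays unseeded, so each run gives a
+     # different result (hence "0 = random" in the UI).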
+     img = input_image.convert("RGB").resize((int(width), int(height)), Image.LANCZOS)
+
+     with INFER():
+         out = _sd_pipe(
+             prompt=str(prompt),
+             image=img,
+             strength=float(strength),
+             guidance_scale=float(guidance),
+             num_inference_steps=int(steps),
+             generator=generator,
+         ).images[0]
+
+     if device == "cuda":
+         torch.cuda.empty_cache()
+     gc.collect()
+     return out
+
+ # -----------------------
+ # Grammar correction models
+ # T5-small (prithivida), T5-base (vennify), GECToR (optional), Llama-3.1-8B-GEC (GGUF)
+ # -----------------------
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+ T5_SMALL = "prithivida/grammar_error_correcter_v1"  # T5-small
+ T5_BASE = "vennify/t5-base-grammar-correction"      # T5-base
+
+ _t5_tok = {}
+ _t5_mdl = {}
+
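+ # Lazy per-model cache: each checkpoint is fetched and moved to `device` once,
+ # then reused for every later request.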
+ def load_t5(model_name: str):
+     if model_name not in _t5_mdl:
+         tok = AutoTokenizer.from_pretrained(model_name)
+         mdl = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
+         _t5_tok[model_name] = tok
+         _t5_mdl[model_name] = mdl
+     return _t5_tok[model_name], _t5_mdl[model_name]
+
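+ # Both checkpoints are T5-style seq2seq models trained with a task prefix
+ # ("gec: " for prithivida, "grammar: " for vennify); the prefix tells the
+ # model which task to perform.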
+ def t5_correct(text: str, model_name: str, max_new_tokens=128):
+     tok, mdl = load_t5(model_name)
+     prefix = "gec: " if "prithivida" in model_name else "grammar: "
+     inputs = tok(prefix + text, return_tensors="pt").to(device)
+     with INFER():
+         out = mdl.generate(**inputs, max_new_tokens=max_new_tokens)
+     return tok.decode(out[0], skip_special_tokens=True)
+
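+ # Hypothetical example: t5_correct("He go to school yesterday", T5_SMALL)
+ # should return something like "He went to school yesterday."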
+ # ---- Optional: GECToR (lazy load) ----
+ _gector_predictor = None
+ _gector_error = None
+ _gector_tried = False
+
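+ # GECToR corrects by sequence tagging rather than generation: it predicts a
+ # per-token edit tag (keep/delete/append/replace) and applies the edits over
+ # a few iterations, which is much faster than autoregressive decoding.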
+ def try_load_gector():
+     global _gector_predictor, _gector_error, _gector_tried
+     if _gector_tried:
+         return _gector_predictor, _gector_error
+     _gector_tried = True
+     try:
+         from gector.gec_model import GecBERTModel  # requires allennlp + pretrained artifacts
+         model_paths = os.environ.get("GEC_MODEL_PATHS", "").strip()
+         vocab_path = os.environ.get("GEC_VOCAB_PATH", "").strip()
+         if not model_paths or not vocab_path:
+             raise RuntimeError(
+                 "GECToR selected but model artifacts are not configured. "
+                 "Set GEC_MODEL_PATHS (space-separated .th files) and GEC_VOCAB_PATH (vocab dir)."
+             )
+         taggers = model_paths.split()
+         # GecBERTModel picks CUDA automatically when available; it takes no device kwarg.
+         _gector_predictor = GecBERTModel(
+             model_paths=taggers,
+             vocab_path=vocab_path,
+             min_error_probability=0.0,
+             confidence=0.0,
+             iterations=2,
+             special_tokens_fix=1,
+         )
+     except Exception as e:
+         _gector_error = str(e)
+         _gector_predictor = None
+     return _gector_predictor, _gector_error
+
+ def gector_correct(text: str):
+     predictor, err = try_load_gector()
+     if err or predictor is None:
+         return (f"[GECToR not active] {err or 'Unknown error.'}\n"
+                 f"Enable by setting GEC_MODEL_PATHS and GEC_VOCAB_PATH to pretrained files.")
+     tokens = text.strip().split()
+     # handle_batch returns (corrected_sentences, total_edit_count)
+     preds, _ = predictor.handle_batch([tokens])
+     return " ".join(preds[0])
+
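+ # Caveat: naive whitespace tokenisation is an approximation; GECToR expects
+ # properly tokenised input, so punctuation attached to words may be handled poorly.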
+ # ---- Llama-3.1-8B GEC (GGUF via llama-cpp-python) ----
+ _llama_model = None
+ _llama_err = None
+ _llama_tried = False
+
+ # Choose a sensible quant filename; adjust if you upload a different one to your Space.
+ LLAMA_REPO = "mradermacher/Llama-3.1-8B-Instruct-Grammatical-Error-Correction-2-GGUF"
+ LLAMA_FILE = os.environ.get("LLAMA_GGUF_FILE", "llama-3.1-8b-instruct-gec.Q4_K_S.gguf")
+
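+ # GGUF is llama.cpp's quantised single-file format; Q4_K_S is a roughly 4-bit
+ # quant that keeps the 8B model small enough to run in CPU RAM.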
+ def try_load_llama():
+     global _llama_model, _llama_err, _llama_tried
+     if _llama_tried:
+         return _llama_model, _llama_err
+     _llama_tried = True
+     try:
+         from llama_cpp import Llama
+         # Load directly from the Hub (no need to download manually)
+         _llama_model = Llama.from_pretrained(
+             repo_id=LLAMA_REPO,
+             filename=LLAMA_FILE,
+             n_ctx=2048,
+             n_threads=int(os.environ.get("OMP_NUM_THREADS", "4")),
+             n_batch=128,
+             verbose=False,
+         )
+     except Exception as e:
+         _llama_model = None
+         _llama_err = str(e)
+     return _llama_model, _llama_err
+
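+ # The stop strings below end generation at the first blank line, so only the
+ # corrected text after "Corrected:" is returned.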
+ def llama_gec_correct(text: str, max_new_tokens=256):
+     mdl, err = try_load_llama()
+     if err or mdl is None:
+         return (f"[Llama GGUF not active] {err or 'Unknown error.'}\n"
+                 f"Check model availability or set LLAMA_GGUF_FILE to a valid filename.")
      prompt = (
+         "You are a precise grammatical error corrector. "
+         "Return only the corrected text without explanations.\n\n"
+         f"Input: {text}\n"
+         "Corrected:"
      )
+     out = mdl(prompt, max_tokens=max_new_tokens, stop=["\n\n", "\nCorrected:"])
+     return out["choices"][0]["text"].strip()
+
+ # -----------------------
+ # Router
+ # -----------------------
+ MODEL_OPTIONS = [
+     "T5-small (prithivida)",
+     "T5-base (vennify)",
+     "GECToR (tagging)",
+     "Llama-3.1-8B-GEC (GGUF)",
+ ]
+
+ def correct_text_router(text: str, model_choice: str, max_new_tokens=128):
+     text = (text or "").strip()
+     if not text:
+         raise gr.Error("Please enter text to correct.")
+     max_new_tokens = int(max_new_tokens)  # sliders can deliver floats
+     if model_choice == "T5-small (prithivida)":
+         return t5_correct(text, T5_SMALL, max_new_tokens=max_new_tokens)
+     if model_choice == "T5-base (vennify)":
+         return t5_correct(text, T5_BASE, max_new_tokens=max_new_tokens)
+     if model_choice == "GECToR (tagging)":
+         return gector_correct(text)
+     if model_choice == "Llama-3.1-8B-GEC (GGUF)":
+         return llama_gec_correct(text, max_new_tokens=max_new_tokens)
+     return "Unknown model selection."
+
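+ # Each backend lazy-loads on first use, so switching models in the UI only
+ # pays each load cost once per process.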
+ # -----------------------
+ # UI
+ # -----------------------
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown(
+         f"# 🎨 Style transfer (SD 1.5 img2img) + ✍️ English correction\n"
+         f"- Device detected: **{device.upper()}**\n"
+         f"- Models: T5-small, T5-base, GECToR, Llama-3.1-8B-GEC (GGUF)\n"
+     )
+
+     with gr.Tab("Image style transfer"):
+         with gr.Row():
+             img_in = gr.Image(label="Input image", type="pil")
+             img_out = gr.Image(label="Styled output")
+         prompt = gr.Textbox(label="Style prompt", placeholder="e.g., watercolor wash, halftone dots, 1960s comic shading")
+         with gr.Row():
+             strength = gr.Slider(0.1, 0.95, value=0.55, step=0.05, label="Style strength")
+             guidance = gr.Slider(1.0, 12.0, value=5.5, step=0.5, label="Guidance")
+             steps = gr.Slider(5, 40, value=18, step=1, label="Steps")
+         with gr.Row():
+             width = gr.Slider(256, 768, value=512, step=64, label="Width")
+             height = gr.Slider(256, 768, value=512, step=64, label="Height")
+             seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
+         run_btn = gr.Button("Transfer style", variant="primary")
+         run_btn.click(
+             fn=sd_style_transfer,
+             inputs=[img_in, prompt, strength, guidance, steps, width, height, seed],
+             outputs=[img_out],
+         )
+
+     with gr.Tab("English grammar correction"):
+         model_choice = gr.Dropdown(MODEL_OPTIONS, value="T5-small (prithivida)", label="Model")
+         txt_in = gr.Textbox(lines=6, label="Input text")
+         max_new = gr.Slider(32, 512, value=128, step=16, label="Max tokens (generation models)")
+         txt_out = gr.Textbox(lines=6, label="Corrected text")
+         corr_btn = gr.Button("Correct", variant="primary")
+         corr_btn.click(
+             fn=correct_text_router,
+             inputs=[txt_in, model_choice, max_new],
+             outputs=[txt_out],
+         )
+
+     gr.Markdown(
+         "Tips:\n"
+         "- On CPU, keep SD at 12–20 steps, guidance 4–7, and 512×512 for reasonable speed.\n"
+         "- T5-small is fastest; T5-base is usually more accurate.\n"
+         "- GECToR needs AllenNLP and pretrained tagger files (set GEC_MODEL_PATHS & GEC_VOCAB_PATH).\n"
+         "- The Llama GGUF loads from the Hub (Q4_K_S by default); adjust LLAMA_GGUF_FILE if needed."
+     )
+
+ if __name__ == "__main__":
+     demo.launch()