AItool committed · verified
Commit 7c83af9 · Parent(s): f52a8f9

Update app.py

Files changed (1): app.py (+61 −264)

app.py CHANGED
@@ -1,266 +1,63 @@
- import os
- import gc
- import time
- import gradio as gr
- import torch
- from PIL import Image
-
- # -----------------------
- # Device + CPU perf knobs
- # -----------------------
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- # Threads (tune for HF CPU Space)
- os.environ.setdefault("OMP_NUM_THREADS", "4")
- os.environ.setdefault("MKL_NUM_THREADS", "4")
- torch.set_num_threads(int(os.environ["OMP_NUM_THREADS"]))
- torch.set_num_interop_threads(max(1, int(int(os.environ["OMP_NUM_THREADS"]) // 2)))
-
- INFER = torch.inference_mode if hasattr(torch, "inference_mode") else torch.no_grad
-
- # -----------------------
- # Stable Diffusion 1.5 (img2img) for style transfer
- # -----------------------
- from diffusers import StableDiffusionImg2ImgPipeline, EulerAncestralDiscreteScheduler
-
- def load_sd15_pipe():
-     pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
-         "runwayml/stable-diffusion-v1-5",
-         safety_checker=None,
-         requires_safety_checker=False,
-     )
-     pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
-     pipe = pipe.to(device)
-     pipe.enable_attention_slicing()
-     pipe.enable_vae_tiling()
-     pipe.enable_vae_slicing()
-     if device == "cuda":
-         pipe.unet.to(memory_format=torch.channels_last)
-     return pipe
-
- _sd_pipe = None
-
- def sd_style_transfer(input_image, prompt, strength=0.55, guidance=5.5, steps=18, width=512, height=512, seed=0):
-     global _sd_pipe
-     if input_image is None:
-         raise gr.Error("Please upload an input image.")
-     if not prompt or not prompt.strip():
-         raise gr.Error("Please provide a style prompt.")
-
-     if _sd_pipe is None:
-         t0 = time.time()
-         _sd_pipe = load_sd15_pipe()
-         print(f"[SD] Pipeline loaded in {time.time()-t0:.2f}s on {device}.", flush=True)
-
-     generator = torch.Generator(device=device) if device == "cuda" else torch.Generator()
-     if isinstance(seed, (int, float)) and int(seed) > 0:
-         generator = generator.manual_seed(int(seed))
-
-     img = input_image.convert("RGB").resize((int(width), int(height)), Image.LANCZOS)
-
-     with INFER():
-         out = _sd_pipe(
-             prompt=str(prompt),
-             image=img,
-             strength=float(strength),
-             guidance_scale=float(guidance),
-             num_inference_steps=int(steps),
-             generator=generator,
-         ).images[0]
-
-     if device == "cuda":
-         torch.cuda.empty_cache()
-     gc.collect()
-     return out
-
- # -----------------------
- # Grammar correction models
- # T5-small (prithivida), T5-base (vennify), GECToR (optional), Llama-3.1-8B-GEC (GGUF)
- # -----------------------
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-
- T5_SMALL = "prithivida/grammar_error_correcter_v1"  # T5-small
- T5_BASE = "vennify/t5-base-grammar-correction"      # T5-base
-
- _t5_tok = {}
- _t5_mdl = {}
-
- def load_t5(model_name: str):
-     if model_name not in _t5_mdl:
-         tok = AutoTokenizer.from_pretrained(model_name)
-         mdl = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
-         _t5_tok[model_name] = tok
-         _t5_mdl[model_name] = mdl
-     return _t5_tok[model_name], _t5_mdl[model_name]
-
- def t5_correct(text: str, model_name: str, max_new_tokens=128):
-     tok, mdl = load_t5(model_name)
-     prefix = "gec: " if "prithivida" in model_name else "grammar: "
-     inputs = tok(prefix + text, return_tensors="pt").to(device)
-     with INFER():
-         out = mdl.generate(**inputs, max_length=max_new_tokens)
-     return tok.decode(out[0], skip_special_tokens=True)

- # ---- Optional: GECToR (lazy load) ----
- _gector_predictor = None
- _gector_error = None
- _gector_tried = False
-
- def try_load_gector():
-     global _gector_predictor, _gector_error, _gector_tried
-     if _gector_tried:
-         return _gector_predictor, _gector_error
-     _gector_tried = True
-     try:
-         from gector.gec_model import GECModel  # requires allennlp + pretrained artifacts
-         model_paths = os.environ.get("GEC_MODEL_PATHS", "").strip()
-         vocab_path = os.environ.get("GEC_VOCAB_PATH", "").strip()
-         if not model_paths or not vocab_path:
-             raise RuntimeError(
-                 "GECToR selected but model artifacts are not configured. "
-                 "Set GEC_MODEL_PATHS (space-separated .th files) and GEC_VOCAB_PATH (vocab dir)."
-             )
-         taggers = model_paths.split()
-         _gector_predictor = GECModel(
-             model_paths=taggers,
-             vocab_path=vocab_path,
-             device=("cuda" if device == "cuda" else "cpu"),
-             min_error_probability=0.0,
-             confidence=0.0,
-             iterations=2,
-             special_tokens_fix=1,
-         )
-     except Exception as e:
-         _gector_error = str(e)
-         _gector_predictor = None
-     return _gector_predictor, _gector_error
-
- def gector_correct(text: str):
-     predictor, err = try_load_gector()
-     if err or predictor is None:
-         return f"[GECToR not active] {err or 'Unknown error.'}\n" \
-                f"Enable by setting GEC_MODEL_PATHS and GEC_VOCAB_PATH to pretrained files."
-     tokens = text.strip().split()
-     corrected = predictor.handle_batch([tokens])[0]
-     return " ".join(corrected)
-
- # ---- Llama-3.1-8B GEC (GGUF via llama-cpp-python) ----
- _llama_model = None
- _llama_err = None
- _llama_tried = False
-
- # Choose a sensible quant filename; adjust if you upload a different one to your Space.
- LLAMA_REPO = "mradermacher/Llama-3.1-8B-Instruct-Grammatical-Error-Correction-2-GGUF"
- LLAMA_FILE = os.environ.get("LLAMA_GGUF_FILE", "llama-3.1-8b-instruct-gec.Q4_K_S.gguf")
-
- def try_load_llama():
-     global _llama_model, _llama_err, _llama_tried
-     if _llama_tried:
-         return _llama_model, _llama_err
-     _llama_tried = True
-     try:
-         from llama_cpp import Llama
-         # Load directly from the Hub (no need to download manually)
-         _llama_model = Llama.from_pretrained(
-             repo_id=LLAMA_REPO,
-             filename=LLAMA_FILE,
-             n_ctx=2048,
-             n_threads=int(os.environ.get("OMP_NUM_THREADS", "4")),
-             n_batch=128,
-             verbose=False
-         )
-     except Exception as e:
-         _llama_model = None
-         _llama_err = str(e)
-     return _llama_model, _llama_err
-
- def llama_gec_correct(text: str, max_new_tokens=256):
-     mdl, err = try_load_llama()
-     if err or mdl is None:
-         return f"[Llama GGUF not active] {err or 'Unknown error.'}\n" \
-                f"Check model availability or set LLAMA_GGUF_FILE to a valid filename."
-     prompt = (
-         "You are a precise grammatical error corrector. "
-         "Return only the corrected text without explanations.\n\n"
-         f"Input: {text}\n"
-         "Corrected:"
-     )
-     out = mdl(prompt, max_tokens=max_new_tokens, stop=["\n\n", "\nCorrected:"])
-     return out["choices"][0]["text"].strip()
-
- # -----------------------
- # Router
- # -----------------------
- MODEL_OPTIONS = [
-     "T5-small (prithivida)",
-     "T5-base (vennify)",
-     "GECToR (tagging)",
-     "Llama-3.1-8B-GEC (GGUF)"
- ]
-
- def correct_text_router(text: str, model_choice: str, max_new_tokens=128):
-     text = (text or "").strip()
-     if not text:
-         raise gr.Error("Please enter text to correct.")
-     if model_choice == "T5-small (prithivida)":
-         return t5_correct(text, T5_SMALL, max_new_tokens=max_new_tokens)
-     if model_choice == "T5-base (vennify)":
-         return t5_correct(text, T5_BASE, max_new_tokens=max_new_tokens)
-     if model_choice == "GECToR (tagging)":
-         return gector_correct(text)
-     if model_choice == "Llama-3.1-8B-GEC (GGUF)":
-         return llama_gec_correct(text, max_new_tokens=max_new_tokens)
-     return "Unknown model selection."
-
- # -----------------------
- # UI
- # -----------------------
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
-     gr.Markdown(
-         f"# 🎨 Style transfer (SD 1.5 img2img) + ✍️ English correction\n"
-         f"- Device detected: **{device.upper()}**\n"
-         f"- Models: T5-small, T5-base, GECToR, Llama-3.1-8B-GEC (GGUF)\n"
-     )
-
-     with gr.Tab("Image style transfer"):
-         with gr.Row():
-             img_in = gr.Image(label="Input image", type="pil")
-             img_out = gr.Image(label="Styled output")
-         prompt = gr.Textbox(label="Style prompt", placeholder="e.g., watercolor wash, halftone dots, 1960s comic shading")
-         with gr.Row():
-             strength = gr.Slider(0.1, 0.95, value=0.55, step=0.05, label="Style strength")
-             guidance = gr.Slider(1.0, 12.0, value=5.5, step=0.5, label="Guidance")
-             steps = gr.Slider(5, 40, value=18, step=1, label="Steps")
-         with gr.Row():
-             width = gr.Slider(256, 768, value=512, step=64, label="Width")
-             height = gr.Slider(256, 768, value=512, step=64, label="Height")
-             seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
-         run_btn = gr.Button("Transfer style", variant="primary")
-         run_btn.click(
-             fn=sd_style_transfer,
-             inputs=[img_in, prompt, strength, guidance, steps, width, height, seed],
-             outputs=[img_out]
-         )
-
-     with gr.Tab("English grammar correction"):
-         model_choice = gr.Dropdown(MODEL_OPTIONS, value="T5-small (prithivida)", label="Model")
-         txt_in = gr.Textbox(lines=6, label="Input text")
-         max_new = gr.Slider(32, 512, value=128, step=16, label="Max tokens (generation models)")
-         txt_out = gr.Textbox(lines=6, label="Corrected text")
-         corr_btn = gr.Button("Correct", variant="primary")
-         corr_btn.click(
-             fn=correct_text_router,
-             inputs=[txt_in, model_choice, max_new],
-             outputs=[txt_out]
        )
-
-     gr.Markdown(
-         "Tips:\n"
-         "- On CPU: steps 12–20, guidance 4–7, 512×512 for SD speed.\n"
-         "- T5-small = fastest, T5-base = more accurate.\n"
-         "- GECToR needs AllenNLP and pretrained tagger files (set GEC_MODEL_PATHS & GEC_VOCAB_PATH).\n"
-         "- Llama GGUF loads from Hub (Q4_K_S by default). Adjust LLAMA_GGUF_FILE if needed."
-     )
-
- if __name__ == "__main__":
-     demo.launch()
+ # app.py

+ import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+
+ # Only the official Google FLAN-T5 models
+ MODEL_OPTIONS = {
+     "FLAN-T5-small (Google)": "google/flan-t5-small",
+     "FLAN-T5-base (Google)": "google/flan-t5-base"
+ }
+
+ # Cache loaded pipelines
+ loaded_pipelines = {}
+
+ def get_pipeline(model_id: str):
+     if model_id not in loaded_pipelines:
+         tokenizer = AutoTokenizer.from_pretrained(model_id)
+         model = AutoModelForSeq2SeqLM.from_pretrained(
+             model_id,
+             low_cpu_mem_usage=True,  # CPU optimization
+             torch_dtype="auto"
        )
+         pipe = pipeline("text2text-generation",
+                         model=model,
+                         tokenizer=tokenizer,
+                         device=-1)
+         # Warm-up to avoid first-call lag
+         _ = pipe("Correct the grammar: test", max_new_tokens=8, do_sample=False)
+         loaded_pipelines[model_id] = pipe
+     return loaded_pipelines[model_id]
+
+ def oxford_polish(sentence: str, model_choice: str) -> str:
+     model_id = MODEL_OPTIONS[model_choice]
+     polisher = get_pipeline(model_id)
+
+     # Minimal prompt for FLAN-T5
+     prompt = f"Correct the grammar and rewrite in formal British English: {sentence}"
+     out = polisher(prompt,
+                    max_new_tokens=60,
+                    do_sample=False,
+                    num_beams=2)
+     text = out[0]["generated_text"].strip()
+
+     # Strip accidental echo
+     if text.startswith(prompt):
+         text = text[len(prompt):].strip()
+     return text
+
+ # Gradio interface
+ demo = gr.Interface(
+     fn=oxford_polish,
+     inputs=[
+         gr.Textbox(lines=2, placeholder="Enter a sentence to correct..."),
+         gr.Dropdown(choices=list(MODEL_OPTIONS.keys()),
+                     value="FLAN-T5-base (Google)",
+                     label="Choose Model")
+     ],
+     outputs=gr.Textbox(label="Oxford-style Correction"),
+     title="Oxford Grammar Polisher",
+     description="Compare Google’s official FLAN-T5 small and base models for grammar correction."
+ )
+
+ demo.launch()
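For reference, the rewritten app reduces to a single text2text-generation call. Below is a minimal standalone sketch of that call pattern, not part of the commit: the model ID and decoding settings are taken from the diff above, and the example sentence is made up.

# Sketch only: exercises the same pipeline call the new app.py makes.
from transformers import pipeline

# device=-1 keeps inference on CPU, matching the Space's configuration.
polisher = pipeline("text2text-generation", model="google/flan-t5-base", device=-1)
out = polisher(
    "Correct the grammar and rewrite in formal British English: she go to school yesterday",
    max_new_tokens=60,
    do_sample=False,
    num_beams=2,
)
print(out[0]["generated_text"])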