import os
import gc
import time
import gradio as gr
import torch
from PIL import Image
# -----------------------
# Device + CPU perf knobs
# -----------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
# Threads (tune for an HF CPU Space). Note: OMP/MKL env vars are read when the
# libraries load, so for full effect set them before torch is imported; the
# torch.set_num_threads call below applies regardless.
os.environ.setdefault("OMP_NUM_THREADS", "4")
os.environ.setdefault("MKL_NUM_THREADS", "4")
torch.set_num_threads(int(os.environ["OMP_NUM_THREADS"]))
torch.set_num_interop_threads(max(1, int(os.environ["OMP_NUM_THREADS"]) // 2))
INFER = torch.inference_mode if hasattr(torch, "inference_mode") else torch.no_grad
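# Tuning note (hypothetical example values): on a larger CPU Space, raise the
# caps via the Space's environment settings (e.g. OMP_NUM_THREADS=8) rather
# than editing this file; values set there are in place before the process
# starts, so OpenMP/MKL pick them up as well as the calls above.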
# -----------------------
# Stable Diffusion 1.5 (img2img) for style transfer
# -----------------------
from diffusers import StableDiffusionImg2ImgPipeline, EulerAncestralDiscreteScheduler
def load_sd15_pipe():
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5",  # Hub mirror; the original runwayml repo was removed
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        safety_checker=None,
        requires_safety_checker=False,
    )
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to(device)
pipe.enable_attention_slicing()
pipe.enable_vae_tiling()
pipe.enable_vae_slicing()
if device == "cuda":
pipe.unet.to(memory_format=torch.channels_last)
return pipe
_sd_pipe = None
def sd_style_transfer(input_image, prompt, strength=0.55, guidance=5.5, steps=18, width=512, height=512, seed=0):
global _sd_pipe
if input_image is None:
raise gr.Error("Please upload an input image.")
if not prompt or not prompt.strip():
raise gr.Error("Please provide a style prompt.")
if _sd_pipe is None:
t0 = time.time()
_sd_pipe = load_sd15_pipe()
print(f"[SD] Pipeline loaded in {time.time()-t0:.2f}s on {device}.", flush=True)
generator = torch.Generator(device=device) if device == "cuda" else torch.Generator()
if isinstance(seed, (int, float)) and int(seed) > 0:
generator = generator.manual_seed(int(seed))
img = input_image.convert("RGB").resize((int(width), int(height)), Image.LANCZOS)
with INFER():
out = _sd_pipe(
prompt=str(prompt),
image=img,
strength=float(strength),
guidance_scale=float(guidance),
num_inference_steps=int(steps),
generator=generator,
).images[0]
if device == "cuda":
torch.cuda.empty_cache()
gc.collect()
return out
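# Hypothetical smoke test for sd_style_transfer (not wired into the UI): run
# this module with DEMO_SD=1 and a local "photo.jpg" to exercise the pipeline
# outside Gradio. The filenames are placeholders, not shipped assets.
if os.environ.get("DEMO_SD") == "1":
    _styled = sd_style_transfer(Image.open("photo.jpg"), "watercolor wash", strength=0.6, seed=42)
    _styled.save("styled.jpg")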
# -----------------------
# Grammar correction models
# T5-small (prithivida), T5-base (vennify), GECToR (optional), Llama-3.1-8B-GEC (GGUF)
# -----------------------
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
T5_SMALL = "prithivida/grammar_error_correcter_v1" # T5-small
T5_BASE = "vennify/t5-base-grammar-correction" # T5-base
_t5_tok = {}
_t5_mdl = {}
def load_t5(model_name: str):
if model_name not in _t5_mdl:
tok = AutoTokenizer.from_pretrained(model_name)
mdl = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
_t5_tok[model_name] = tok
_t5_mdl[model_name] = mdl
return _t5_tok[model_name], _t5_mdl[model_name]
def t5_correct(text: str, model_name: str, max_new_tokens=128):
tok, mdl = load_t5(model_name)
prefix = "gec: " if "prithivida" in model_name else "grammar: "
inputs = tok(prefix + text, return_tensors="pt").to(device)
with INFER():
        out = mdl.generate(**inputs, max_new_tokens=int(max_new_tokens))
return tok.decode(out[0], skip_special_tokens=True)
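# Hypothetical quick check: both checkpoints are prefix-conditioned seq2seq
# models, so a call like the one below (guarded so nothing loads at import
# time) should return a corrected sentence.
if os.environ.get("DEMO_T5") == "1":
    print(t5_correct("He go to school yesterday.", T5_SMALL))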
# ---- Optional: GECToR (lazy load) ----
_gector_predictor = None
_gector_error = None
_gector_tried = False
def try_load_gector():
global _gector_predictor, _gector_error, _gector_tried
if _gector_tried:
return _gector_predictor, _gector_error
_gector_tried = True
try:
        from gector.gec_model import GecBERTModel  # grammarly/gector; requires allennlp + pretrained artifacts
model_paths = os.environ.get("GEC_MODEL_PATHS", "").strip()
vocab_path = os.environ.get("GEC_VOCAB_PATH", "").strip()
if not model_paths or not vocab_path:
raise RuntimeError(
"GECToR selected but model artifacts are not configured. "
"Set GEC_MODEL_PATHS (space-separated .th files) and GEC_VOCAB_PATH (vocab dir)."
)
taggers = model_paths.split()
        _gector_predictor = GecBERTModel(
            vocab_path=vocab_path,
            model_paths=taggers,
            min_error_probability=0.0,
            confidence=0.0,
            iterations=2,
            special_tokens_fix=1,
        )  # GecBERTModel selects CUDA automatically when available
except Exception as e:
_gector_error = str(e)
_gector_predictor = None
return _gector_predictor, _gector_error
def gector_correct(text: str):
predictor, err = try_load_gector()
if err or predictor is None:
return f"[GECToR not active] {err or 'Unknown error.'}\n" \
f"Enable by setting GEC_MODEL_PATHS and GEC_VOCAB_PATH to pretrained files."
tokens = text.strip().split()
    batch, _ = predictor.handle_batch([tokens])  # returns (corrected_batch, total_updates)
    return " ".join(batch[0])
# ---- Llama-3.1-8B GEC (GGUF via llama-cpp-python) ----
_llama_model = None
_llama_err = None
_llama_tried = False
# Choose a sensible quant filename; adjust if you upload a different one to your Space.
LLAMA_REPO = "mradermacher/Llama-3.1-8B-Instruct-Grammatical-Error-Correction-2-GGUF"
LLAMA_FILE = os.environ.get("LLAMA_GGUF_FILE", "llama-3.1-8b-instruct-gec.Q4_K_S.gguf")
def try_load_llama():
global _llama_model, _llama_err, _llama_tried
if _llama_tried:
return _llama_model, _llama_err
_llama_tried = True
try:
from llama_cpp import Llama
# Load directly from Hub (no need to manually download)
_llama_model = Llama.from_pretrained(
repo_id=LLAMA_REPO,
filename=LLAMA_FILE,
n_ctx=2048,
n_threads=int(os.environ.get("OMP_NUM_THREADS", "4")),
n_batch=128,
verbose=False
)
except Exception as e:
_llama_model = None
_llama_err = str(e)
return _llama_model, _llama_err
def llama_gec_correct(text: str, max_new_tokens=256):
mdl, err = try_load_llama()
if err or mdl is None:
return f"[Llama GGUF not active] {err or 'Unknown error.'}\n" \
f"Check model availability or set LLAMA_GGUF_FILE to a valid filename."
prompt = (
"You are a precise grammatical error corrector. "
"Return only the corrected text without explanations.\n\n"
f"Input: {text}\n"
"Corrected:"
)
out = mdl(prompt, max_tokens=max_new_tokens, stop=["\n\n", "\nCorrected:"])
return out["choices"][0]["text"].strip()
# -----------------------
# Router
# -----------------------
MODEL_OPTIONS = [
"T5-small (prithivida)",
"T5-base (vennify)",
"GECToR (tagging)",
"Llama-3.1-8B-GEC (GGUF)"
]
def correct_text_router(text: str, model_choice: str, max_new_tokens=128):
text = (text or "").strip()
if not text:
raise gr.Error("Please enter text to correct.")
if model_choice == "T5-small (prithivida)":
return t5_correct(text, T5_SMALL, max_new_tokens=max_new_tokens)
if model_choice == "T5-base (vennify)":
return t5_correct(text, T5_BASE, max_new_tokens=max_new_tokens)
if model_choice == "GECToR (tagging)":
return gector_correct(text)
if model_choice == "Llama-3.1-8B-GEC (GGUF)":
return llama_gec_correct(text, max_new_tokens=max_new_tokens)
return "Unknown model selection."
# -----------------------
# UI
# -----------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown(
f"# 🎨 Style transfer (SD 1.5 img2img) + ✍️ English correction\n"
f"- Device detected: **{device.upper()}**\n"
f"- Models: T5-small, T5-base, GECToR, Llama-3.1-8B-GEC (GGUF)\n"
)
with gr.Tab("Image style transfer"):
with gr.Row():
img_in = gr.Image(label="Input image", type="pil")
img_out = gr.Image(label="Styled output")
prompt = gr.Textbox(label="Style prompt", placeholder="e.g., watercolor wash, halftone dots, 1960s comic shading")
with gr.Row():
strength = gr.Slider(0.1, 0.95, value=0.55, step=0.05, label="Style strength")
guidance = gr.Slider(1.0, 12.0, value=5.5, step=0.5, label="Guidance")
steps = gr.Slider(5, 40, value=18, step=1, label="Steps")
with gr.Row():
width = gr.Slider(256, 768, value=512, step=64, label="Width")
height = gr.Slider(256, 768, value=512, step=64, label="Height")
seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
run_btn = gr.Button("Transfer style", variant="primary")
run_btn.click(
fn=sd_style_transfer,
inputs=[img_in, prompt, strength, guidance, steps, width, height, seed],
outputs=[img_out]
)
with gr.Tab("English grammar correction"):
model_choice = gr.Dropdown(MODEL_OPTIONS, value="T5-small (prithivida)", label="Model")
txt_in = gr.Textbox(lines=6, label="Input text")
max_new = gr.Slider(32, 512, value=128, step=16, label="Max tokens (generation models)")
txt_out = gr.Textbox(lines=6, label="Corrected text")
corr_btn = gr.Button("Correct", variant="primary")
corr_btn.click(
fn=correct_text_router,
inputs=[txt_in, model_choice, max_new],
outputs=[txt_out]
)
gr.Markdown(
"Tips:\n"
"- On CPU: steps 12–20, guidance 4–7, 512×512 for SD speed.\n"
"- T5-small = fastest, T5-base = more accurate.\n"
"- GECToR needs AllenNLP and pretrained tagger files (set GEC_MODEL_PATHS & GEC_VOCAB_PATH).\n"
"- Llama GGUF loads from Hub (Q4_K_S by default). Adjust LLAMA_GGUF_FILE if needed."
)
if __name__ == "__main__":
demo.launch()