import os
import gc
import time

import gradio as gr
import torch
from PIL import Image

# -----------------------
# Device + CPU perf knobs
# -----------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

# Threads (tune for HF CPU Space)
os.environ.setdefault("OMP_NUM_THREADS", "4")
os.environ.setdefault("MKL_NUM_THREADS", "4")
torch.set_num_threads(int(os.environ["OMP_NUM_THREADS"]))
torch.set_num_interop_threads(max(1, int(os.environ["OMP_NUM_THREADS"]) // 2))

INFER = torch.inference_mode if hasattr(torch, "inference_mode") else torch.no_grad

# -----------------------
# Stable Diffusion 1.5 (img2img) for style transfer
# -----------------------
from diffusers import StableDiffusionImg2ImgPipeline, EulerAncestralDiscreteScheduler


def load_sd15_pipe():
    # NOTE: the original runwayml repo has since been removed from the Hub;
    # if this 404s, the community mirror "stable-diffusion-v1-5/stable-diffusion-v1-5"
    # can be swapped in.
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        safety_checker=None,
        requires_safety_checker=False,
    )
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(device)
    # Memory-saving toggles; these also help on small CPU Spaces.
    pipe.enable_attention_slicing()
    pipe.enable_vae_tiling()
    pipe.enable_vae_slicing()
    if device == "cuda":
        pipe.unet.to(memory_format=torch.channels_last)
    return pipe


_sd_pipe = None


def sd_style_transfer(input_image, prompt, strength=0.55, guidance=5.5,
                      steps=18, width=512, height=512, seed=0):
    global _sd_pipe
    if input_image is None:
        raise gr.Error("Please upload an input image.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please provide a style prompt.")
    if _sd_pipe is None:
        t0 = time.time()
        _sd_pipe = load_sd15_pipe()
        print(f"[SD] Pipeline loaded in {time.time()-t0:.2f}s on {device}.", flush=True)

    generator = torch.Generator(device=device) if device == "cuda" else torch.Generator()
    if isinstance(seed, (int, float)) and int(seed) > 0:
        generator = generator.manual_seed(int(seed))

    img = input_image.convert("RGB").resize((int(width), int(height)), Image.LANCZOS)
    with INFER():
        out = _sd_pipe(
            prompt=str(prompt),
            image=img,
            strength=float(strength),
            guidance_scale=float(guidance),
            num_inference_steps=int(steps),
            generator=generator,
        ).images[0]

    if device == "cuda":
        torch.cuda.empty_cache()
    gc.collect()
    return out

# -----------------------
# Grammar correction models
# T5-small (prithivida), T5-base (vennify), GECToR (optional), Llama-3.1-8B-GEC (GGUF)
# -----------------------
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

T5_SMALL = "prithivida/grammar_error_correcter_v1"  # T5-small
T5_BASE = "vennify/t5-base-grammar-correction"      # T5-base

_t5_tok = {}
_t5_mdl = {}


def load_t5(model_name: str):
    if model_name not in _t5_mdl:
        tok = AutoTokenizer.from_pretrained(model_name)
        mdl = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
        _t5_tok[model_name] = tok
        _t5_mdl[model_name] = mdl
    return _t5_tok[model_name], _t5_mdl[model_name]


def t5_correct(text: str, model_name: str, max_new_tokens=128):
    tok, mdl = load_t5(model_name)
    prefix = "gec: " if "prithivida" in model_name else "grammar: "
    inputs = tok(prefix + text, return_tensors="pt").to(device)
    with INFER():
        # max_new_tokens (not max_length) so the budget excludes the input length.
        out = mdl.generate(**inputs, max_new_tokens=int(max_new_tokens))
    return tok.decode(out[0], skip_special_tokens=True)

# ---- Optional: GECToR (lazy load) ----
_gector_predictor = None
_gector_error = None
_gector_tried = False


def try_load_gector():
    global _gector_predictor, _gector_error, _gector_tried
    if _gector_tried:
        return _gector_predictor, _gector_error
    _gector_tried = True
    try:
        # grammarly/gector exposes GecBERTModel (requires allennlp + pretrained artifacts).
        from gector.gec_model import GecBERTModel
os.environ.get("GEC_MODEL_PATHS", "").strip() vocab_path = os.environ.get("GEC_VOCAB_PATH", "").strip() if not model_paths or not vocab_path: raise RuntimeError( "GECToR selected but model artifacts are not configured. " "Set GEC_MODEL_PATHS (space-separated .th files) and GEC_VOCAB_PATH (vocab dir)." ) taggers = model_paths.split() _gector_predictor = GECModel( model_paths=taggers, vocab_path=vocab_path, device=("cuda" if device == "cuda" else "cpu"), min_error_probability=0.0, confidence=0.0, iterations=2, special_tokens_fix=1, ) except Exception as e: _gector_error = str(e) _gector_predictor = None return _gector_predictor, _gector_error def gector_correct(text: str): predictor, err = try_load_gector() if err or predictor is None: return f"[GECToR not active] {err or 'Unknown error.'}\n" \ f"Enable by setting GEC_MODEL_PATHS and GEC_VOCAB_PATH to pretrained files." tokens = text.strip().split() corrected = predictor.handle_batch([tokens])[0] return " ".join(corrected) # ---- Llama-3.1-8B GEC (GGUF via llama-cpp-python) ---- _llama_model = None _llama_err = None _llama_tried = False # Choose a sensible quant filename; adjust if you upload a different one to your Space. LLAMA_REPO = "mradermacher/Llama-3.1-8B-Instruct-Grammatical-Error-Correction-2-GGUF" LLAMA_FILE = os.environ.get("LLAMA_GGUF_FILE", "llama-3.1-8b-instruct-gec.Q4_K_S.gguf") def try_load_llama(): global _llama_model, _llama_err, _llama_tried if _llama_tried: return _llama_model, _llama_err _llama_tried = True try: from llama_cpp import Llama # Load directly from Hub (no need to manually download) _llama_model = Llama.from_pretrained( repo_id=LLAMA_REPO, filename=LLAMA_FILE, n_ctx=2048, n_threads=int(os.environ.get("OMP_NUM_THREADS", "4")), n_batch=128, verbose=False ) except Exception as e: _llama_model = None _llama_err = str(e) return _llama_model, _llama_err def llama_gec_correct(text: str, max_new_tokens=256): mdl, err = try_load_llama() if err or mdl is None: return f"[Llama GGUF not active] {err or 'Unknown error.'}\n" \ f"Check model availability or set LLAMA_GGUF_FILE to a valid filename." prompt = ( "You are a precise grammatical error corrector. " "Return only the corrected text without explanations.\n\n" f"Input: {text}\n" "Corrected:" ) out = mdl(prompt, max_tokens=max_new_tokens, stop=["\n\n", "\nCorrected:"]) return out["choices"][0]["text"].strip() # ----------------------- # Router # ----------------------- MODEL_OPTIONS = [ "T5-small (prithivida)", "T5-base (vennify)", "GECToR (tagging)", "Llama-3.1-8B-GEC (GGUF)" ] def correct_text_router(text: str, model_choice: str, max_new_tokens=128): text = (text or "").strip() if not text: raise gr.Error("Please enter text to correct.") if model_choice == "T5-small (prithivida)": return t5_correct(text, T5_SMALL, max_new_tokens=max_new_tokens) if model_choice == "T5-base (vennify)": return t5_correct(text, T5_BASE, max_new_tokens=max_new_tokens) if model_choice == "GECToR (tagging)": return gector_correct(text) if model_choice == "Llama-3.1-8B-GEC (GGUF)": return llama_gec_correct(text, max_new_tokens=max_new_tokens) return "Unknown model selection." 
# -----------------------
# UI
# -----------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"# 🎨 Style transfer (SD 1.5 img2img) + ✍️ English correction\n"
        f"- Device detected: **{device.upper()}**\n"
        f"- Models: T5-small, T5-base, GECToR, Llama-3.1-8B-GEC (GGUF)\n"
    )

    with gr.Tab("Image style transfer"):
        with gr.Row():
            img_in = gr.Image(label="Input image", type="pil")
            img_out = gr.Image(label="Styled output")
        prompt = gr.Textbox(
            label="Style prompt",
            placeholder="e.g., watercolor wash, halftone dots, 1960s comic shading",
        )
        with gr.Row():
            strength = gr.Slider(0.1, 0.95, value=0.55, step=0.05, label="Style strength")
            guidance = gr.Slider(1.0, 12.0, value=5.5, step=0.5, label="Guidance")
            steps = gr.Slider(5, 40, value=18, step=1, label="Steps")
        with gr.Row():
            width = gr.Slider(256, 768, value=512, step=64, label="Width")
            height = gr.Slider(256, 768, value=512, step=64, label="Height")
            seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
        run_btn = gr.Button("Transfer style", variant="primary")
        run_btn.click(
            fn=sd_style_transfer,
            inputs=[img_in, prompt, strength, guidance, steps, width, height, seed],
            outputs=[img_out],
        )

    with gr.Tab("English grammar correction"):
        model_choice = gr.Dropdown(MODEL_OPTIONS, value="T5-small (prithivida)", label="Model")
        txt_in = gr.Textbox(lines=6, label="Input text")
        max_new = gr.Slider(32, 512, value=128, step=16, label="Max tokens (generation models)")
        txt_out = gr.Textbox(lines=6, label="Corrected text")
        corr_btn = gr.Button("Correct", variant="primary")
        corr_btn.click(
            fn=correct_text_router,
            inputs=[txt_in, model_choice, max_new],
            outputs=[txt_out],
        )

    gr.Markdown(
        "Tips:\n"
        "- On CPU: steps 12–20, guidance 4–7, 512×512 for SD speed.\n"
        "- T5-small = fastest, T5-base = more accurate.\n"
        "- GECToR needs AllenNLP and pretrained tagger files (set GEC_MODEL_PATHS & GEC_VOCAB_PATH).\n"
        "- Llama GGUF loads from Hub (Q4_K_S by default). Adjust LLAMA_GGUF_FILE if needed."
    )

if __name__ == "__main__":
    demo.launch()
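# Usage sketch (assumption: run from a separate process while the app is up).
# gradio_client can drive the text tab programmatically; endpoint names are
# auto-generated from the handler function name, so confirm the route with
# client.view_api() before relying on it:
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860")
#   print(client.predict("She go to school.", "T5-small (prithivida)", 128,
#                        api_name="/correct_text_router"))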