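"""Gradio app for a Hugging Face Space with two tabs:

1. Image style transfer via Stable Diffusion 1.5 img2img.
2. English grammar correction via T5-small, T5-base, GECToR (optional),
   or a Llama-3.1-8B GEC model in GGUF format (via llama-cpp-python).

All heavy models are lazy-loaded on first use to keep Space startup fast.
"""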
import os
import gc
import time

import gradio as gr
import torch
from PIL import Image

# -----------------------
# Device + CPU perf knobs
# -----------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

# Threads (tune for the HF CPU Space). torch is already imported at this point,
# so these env vars mainly act as a single source of truth for the thread
# counts handed to torch here and to llama.cpp below.
os.environ.setdefault("OMP_NUM_THREADS", "4")
os.environ.setdefault("MKL_NUM_THREADS", "4")
torch.set_num_threads(int(os.environ["OMP_NUM_THREADS"]))
torch.set_num_interop_threads(max(1, int(os.environ["OMP_NUM_THREADS"]) // 2))

# inference_mode is faster than no_grad on recent PyTorch; fall back gracefully.
INFER = torch.inference_mode if hasattr(torch, "inference_mode") else torch.no_grad

# -----------------------
# Stable Diffusion 1.5 (img2img) for style transfer
# -----------------------
from diffusers import StableDiffusionImg2ImgPipeline, EulerAncestralDiscreteScheduler

def load_sd15_pipe():
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        safety_checker=None,
        requires_safety_checker=False,
    )
    # Euler Ancestral gives good results at low step counts.
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(device)
    # Reduce peak memory during attention and VAE decode.
    pipe.enable_attention_slicing()
    pipe.enable_vae_tiling()
    pipe.enable_vae_slicing()
    if device == "cuda":
        pipe.unet.to(memory_format=torch.channels_last)
    return pipe

_sd_pipe = None  # lazy-loaded on first request
def sd_style_transfer(input_image, prompt, strength=0.55, guidance=5.5, steps=18,
                      width=512, height=512, seed=0):
    global _sd_pipe
    if input_image is None:
        raise gr.Error("Please upload an input image.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please provide a style prompt.")
    if _sd_pipe is None:
        t0 = time.time()
        _sd_pipe = load_sd15_pipe()
        print(f"[SD] Pipeline loaded in {time.time() - t0:.2f}s on {device}.", flush=True)
    # seed <= 0 leaves the generator unseeded, i.e. a random result each run.
    generator = torch.Generator(device=device) if device == "cuda" else torch.Generator()
    if isinstance(seed, (int, float)) and int(seed) > 0:
        generator = generator.manual_seed(int(seed))
    img = input_image.convert("RGB").resize((int(width), int(height)), Image.LANCZOS)
    with INFER():
        out = _sd_pipe(
            prompt=str(prompt),
            image=img,
            strength=float(strength),
            guidance_scale=float(guidance),
            num_inference_steps=int(steps),
            generator=generator,
        ).images[0]
    if device == "cuda":
        torch.cuda.empty_cache()
    gc.collect()
    return out
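
# Quick sanity check outside the UI (a sketch; assumes a local "photo.jpg" exists):
#   from PIL import Image
#   styled = sd_style_transfer(Image.open("photo.jpg"), "watercolor wash", seed=42)
#   styled.save("styled.jpg")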

# -----------------------
# Grammar correction models
# T5-small (prithivida), T5-base (vennify), GECToR (optional), Llama-3.1-8B-GEC (GGUF)
# -----------------------
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

T5_SMALL = "prithivida/grammar_error_correcter_v1"  # T5-small
T5_BASE = "vennify/t5-base-grammar-correction"      # T5-base

# In-process caches so each checkpoint is loaded at most once.
_t5_tok = {}
_t5_mdl = {}

def load_t5(model_name: str):
    if model_name not in _t5_mdl:
        tok = AutoTokenizer.from_pretrained(model_name)
        mdl = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
        _t5_tok[model_name] = tok
        _t5_mdl[model_name] = mdl
    return _t5_tok[model_name], _t5_mdl[model_name]
def t5_correct(text: str, model_name: str, max_new_tokens=128):
    tok, mdl = load_t5(model_name)
    # Each checkpoint was trained with its own task prefix.
    prefix = "gec: " if "prithivida" in model_name else "grammar: "
    inputs = tok(prefix + text, return_tensors="pt").to(device)
    with INFER():
        out = mdl.generate(**inputs, max_new_tokens=int(max_new_tokens))
    return tok.decode(out[0], skip_special_tokens=True)
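
# Example (a sketch; the exact output depends on the checkpoint):
#   t5_correct("He go to school yesterday.", T5_SMALL)
#   -> "He went to school yesterday."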

# ---- Optional: GECToR (lazy load) ----
_gector_predictor = None
_gector_error = None
_gector_tried = False

def try_load_gector():
    global _gector_predictor, _gector_error, _gector_tried
    if _gector_tried:
        return _gector_predictor, _gector_error
    _gector_tried = True
    try:
        # The grammarly/gector repo exposes GecBERTModel (requires allennlp +
        # pretrained artifacts); device selection happens inside the library.
        from gector.gec_model import GecBERTModel
        model_paths = os.environ.get("GEC_MODEL_PATHS", "").strip()
        vocab_path = os.environ.get("GEC_VOCAB_PATH", "").strip()
        if not model_paths or not vocab_path:
            raise RuntimeError(
                "GECToR selected but model artifacts are not configured. "
                "Set GEC_MODEL_PATHS (space-separated .th files) and GEC_VOCAB_PATH (vocab dir)."
            )
        _gector_predictor = GecBERTModel(
            model_paths=model_paths.split(),
            vocab_path=vocab_path,
            min_error_probability=0.0,
            confidence=0.0,
            iterations=2,
            special_tokens_fix=1,
        )
    except Exception as e:
        _gector_error = str(e)
        _gector_predictor = None
    return _gector_predictor, _gector_error
def gector_correct(text: str):
    predictor, err = try_load_gector()
    if err or predictor is None:
        return (f"[GECToR not active] {err or 'Unknown error.'}\n"
                f"Enable by setting GEC_MODEL_PATHS and GEC_VOCAB_PATH to pretrained files.")
    tokens = text.strip().split()
    # handle_batch returns (corrected_sentences, correction_count).
    batch, _ = predictor.handle_batch([tokens])
    return " ".join(batch[0])

# ---- Llama-3.1-8B GEC (GGUF via llama-cpp-python) ----
_llama_model = None
_llama_err = None
_llama_tried = False

# Choose a sensible quant filename; adjust if you upload a different one to your Space.
LLAMA_REPO = "mradermacher/Llama-3.1-8B-Instruct-Grammatical-Error-Correction-2-GGUF"
LLAMA_FILE = os.environ.get("LLAMA_GGUF_FILE", "llama-3.1-8b-instruct-gec.Q4_K_S.gguf")

def try_load_llama():
    global _llama_model, _llama_err, _llama_tried
    if _llama_tried:
        return _llama_model, _llama_err
    _llama_tried = True
    try:
        from llama_cpp import Llama
        # Load directly from the Hub (no need to download manually).
        _llama_model = Llama.from_pretrained(
            repo_id=LLAMA_REPO,
            filename=LLAMA_FILE,
            n_ctx=2048,
            n_threads=int(os.environ.get("OMP_NUM_THREADS", "4")),
            n_batch=128,
            verbose=False,
        )
    except Exception as e:
        _llama_model = None
        _llama_err = str(e)
    return _llama_model, _llama_err
def llama_gec_correct(text: str, max_new_tokens=256):
    mdl, err = try_load_llama()
    if err or mdl is None:
        return (f"[Llama GGUF not active] {err or 'Unknown error.'}\n"
                f"Check model availability or set LLAMA_GGUF_FILE to a valid filename.")
    prompt = (
        "You are a precise grammatical error corrector. "
        "Return only the corrected text without explanations.\n\n"
        f"Input: {text}\n"
        "Corrected:"
    )
    out = mdl(prompt, max_tokens=max_new_tokens, stop=["\n\n", "\nCorrected:"])
    return out["choices"][0]["text"].strip()

# -----------------------
# Router
# -----------------------
MODEL_OPTIONS = [
    "T5-small (prithivida)",
    "T5-base (vennify)",
    "GECToR (tagging)",
    "Llama-3.1-8B-GEC (GGUF)",
]

def correct_text_router(text: str, model_choice: str, max_new_tokens=128):
    text = (text or "").strip()
    if not text:
        raise gr.Error("Please enter text to correct.")
    if model_choice == "T5-small (prithivida)":
        return t5_correct(text, T5_SMALL, max_new_tokens=max_new_tokens)
    if model_choice == "T5-base (vennify)":
        return t5_correct(text, T5_BASE, max_new_tokens=max_new_tokens)
    if model_choice == "GECToR (tagging)":
        return gector_correct(text)
    if model_choice == "Llama-3.1-8B-GEC (GGUF)":
        return llama_gec_correct(text, max_new_tokens=max_new_tokens)
    return "Unknown model selection."

# -----------------------
# UI
# -----------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"# 🎨 Style transfer (SD 1.5 img2img) + ✍️ English correction\n"
        f"- Device detected: **{device.upper()}**\n"
        f"- Models: T5-small, T5-base, GECToR, Llama-3.1-8B-GEC (GGUF)\n"
    )

    with gr.Tab("Image style transfer"):
        with gr.Row():
            img_in = gr.Image(label="Input image", type="pil")
            img_out = gr.Image(label="Styled output")
        prompt = gr.Textbox(label="Style prompt",
                            placeholder="e.g., watercolor wash, halftone dots, 1960s comic shading")
        with gr.Row():
            strength = gr.Slider(0.1, 0.95, value=0.55, step=0.05, label="Style strength")
            guidance = gr.Slider(1.0, 12.0, value=5.5, step=0.5, label="Guidance")
            steps = gr.Slider(5, 40, value=18, step=1, label="Steps")
        with gr.Row():
            width = gr.Slider(256, 768, value=512, step=64, label="Width")
            height = gr.Slider(256, 768, value=512, step=64, label="Height")
            seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
        run_btn = gr.Button("Transfer style", variant="primary")
        run_btn.click(
            fn=sd_style_transfer,
            inputs=[img_in, prompt, strength, guidance, steps, width, height, seed],
            outputs=[img_out],
        )

    with gr.Tab("English grammar correction"):
        model_choice = gr.Dropdown(MODEL_OPTIONS, value="T5-small (prithivida)", label="Model")
        txt_in = gr.Textbox(lines=6, label="Input text")
        max_new = gr.Slider(32, 512, value=128, step=16, label="Max tokens (generation models)")
        txt_out = gr.Textbox(lines=6, label="Corrected text")
        corr_btn = gr.Button("Correct", variant="primary")
        corr_btn.click(
            fn=correct_text_router,
            inputs=[txt_in, model_choice, max_new],
            outputs=[txt_out],
        )
    gr.Markdown(
        "Tips:\n"
        "- On CPU, keep SD fast: 12–20 steps, guidance 4–7, 512×512.\n"
        "- T5-small is fastest; T5-base is more accurate.\n"
        "- GECToR needs AllenNLP plus pretrained tagger files (set GEC_MODEL_PATHS and GEC_VOCAB_PATH).\n"
        "- The Llama GGUF loads from the Hub (Q4_K_S by default); set LLAMA_GGUF_FILE to pick another quant."
    )

if __name__ == "__main__":
    demo.launch()