# tinyInstruct / app.py
# (Hugging Face Space file-view header captured with this source: author "AItool",
#  commit "Update app.py" f52a8f9, file size 10 kB.)
import os
import gc
import time
import gradio as gr
import torch
from PIL import Image
# -----------------------
# Device + CPU perf knobs
# -----------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
# Threads (tune for HF CPU Space)
os.environ.setdefault("OMP_NUM_THREADS", "4")
os.environ.setdefault("MKL_NUM_THREADS", "4")
torch.set_num_threads(int(os.environ["OMP_NUM_THREADS"]))
torch.set_num_interop_threads(max(1, int(int(os.environ["OMP_NUM_THREADS"]) // 2)))
INFER = torch.inference_mode if hasattr(torch, "inference_mode") else torch.no_grad
# -----------------------
# Stable Diffusion 1.5 (img2img) for style transfer
# -----------------------
from diffusers import StableDiffusionImg2ImgPipeline, EulerAncestralDiscreteScheduler
# Hub id of the SD 1.5 checkpoint. The original "runwayml/stable-diffusion-v1-5"
# repo was taken down from the Hugging Face Hub; the maintained mirror lives under
# the "stable-diffusion-v1-5" org. Set SD15_MODEL_ID to use a fork or local path.
SD15_MODEL_ID = os.environ.get("SD15_MODEL_ID", "stable-diffusion-v1-5/stable-diffusion-v1-5")


def load_sd15_pipe(model_id=None):
    """Build the SD 1.5 img2img pipeline used for style transfer.

    Args:
        model_id: Hub id or local path of an SD 1.5 checkpoint; defaults to
            ``SD15_MODEL_ID``.

    Returns:
        A ``StableDiffusionImg2ImgPipeline`` on ``device``. It has the Euler
        Ancestral scheduler, the safety checker disabled, and attention/VAE
        slicing plus VAE tiling enabled to lower peak memory.
    """
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        model_id or SD15_MODEL_ID,
        safety_checker=None,
        requires_safety_checker=False,
    )
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(device)
    pipe.enable_attention_slicing()
    pipe.enable_vae_tiling()
    pipe.enable_vae_slicing()
    if device == "cuda":
        # channels_last memory format is generally faster for the UNet convolutions on GPU.
        pipe.unet.to(memory_format=torch.channels_last)
    return pipe
_sd_pipe = None  # built lazily on the first request by sd_style_transfer


def sd_style_transfer(input_image, prompt, strength=0.55, guidance=5.5, steps=18, width=512, height=512, seed=0):
    """Restyle ``input_image`` with SD 1.5 img2img, guided by ``prompt``.

    Args:
        input_image: PIL image; converted to RGB and resized to
            ``width`` x ``height`` (aspect ratio is not preserved).
        prompt: non-empty style description.
        strength: img2img strength (higher departs further from the input).
        guidance: classifier-free guidance scale.
        steps: number of inference steps.
        width, height: output size in pixels.
        seed: a positive value gives a reproducible result; 0, a negative value
            or a non-number gives a fresh random seed.

    Returns:
        The first generated PIL image.

    Raises:
        gr.Error: if the image or the prompt is missing.
    """
    global _sd_pipe
    if input_image is None:
        raise gr.Error("Please upload an input image.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please provide a style prompt.")

    if _sd_pipe is None:
        t0 = time.time()
        _sd_pipe = load_sd15_pipe()
        print(f"[SD] Pipeline loaded in {time.time()-t0:.2f}s on {device}.", flush=True)

    generator = torch.Generator(device=device) if device == "cuda" else torch.Generator()
    if isinstance(seed, (int, float)) and int(seed) > 0:
        generator = generator.manual_seed(int(seed))
    else:
        # A new torch.Generator always starts from the same fixed default seed, so
        # without reseeding every "random" run would produce the identical image.
        generator.seed()

    img = input_image.convert("RGB").resize((int(width), int(height)), Image.LANCZOS)
    with INFER():
        out = _sd_pipe(
            prompt=str(prompt),
            image=img,
            strength=float(strength),
            guidance_scale=float(guidance),
            num_inference_steps=int(steps),
            generator=generator,
        ).images[0]

    # Release cached GPU blocks / Python garbage between requests.
    if device == "cuda":
        torch.cuda.empty_cache()
    gc.collect()
    return out
# -----------------------
# Grammar correction models
# T5-small (prithivida), T5-base (vennify), GECToR (optional), Llama-3.1-8B-GEC (GGUF)
# -----------------------
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
T5_SMALL = "prithivida/grammar_error_correcter_v1"  # T5-small
T5_BASE = "vennify/t5-base-grammar-correction"  # T5-base

# Per-checkpoint caches keyed by Hub id, so each model is loaded at most once.
_t5_tok = {}
_t5_mdl = {}


def load_t5(model_name: str):
    """Return the cached ``(tokenizer, model)`` pair for ``model_name``.

    On first use, both are fetched with ``from_pretrained``; the model is moved
    to ``device`` before being cached.
    """
    if model_name in _t5_mdl:
        return _t5_tok[model_name], _t5_mdl[model_name]
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    _t5_tok[model_name], _t5_mdl[model_name] = tokenizer, model
    return tokenizer, model
def t5_correct(text: str, model_name: str, max_new_tokens=128):
    """Correct ``text`` with one of the T5 grammar checkpoints.

    Args:
        text: input text.
        model_name: Hub id. It selects the task prefix: "gec: " for the
            prithivida checkpoint, "grammar: " otherwise.
        max_new_tokens: cap on the number of generated tokens. It is cast to
            int in case the UI slider delivers a float.

    Returns:
        The decoded correction with special tokens removed.
    """
    tok, mdl = load_t5(model_name)
    prefix = "gec: " if "prithivida" in model_name else "grammar: "
    inputs = tok(prefix + text, return_tensors="pt").to(device)
    with INFER():
        # Use max_new_tokens (not max_length) so the cap means what the parameter name says.
        out = mdl.generate(**inputs, max_new_tokens=int(max_new_tokens))
    return tok.decode(out[0], skip_special_tokens=True)
# ---- Optional: GECToR (lazy load) ----
_gector_predictor = None  # loaded predictor, once available
_gector_error = None  # message from a failed load attempt
_gector_tried = False  # loading is attempted at most once per process


def try_load_gector():
    """Load the GECToR tagger once and return ``(predictor, error)``.

    Configuration comes from environment variables:
      - GEC_MODEL_PATHS: space-separated .th checkpoint files.
      - GEC_VOCAB_PATH: the vocabulary directory.
      - GEC_TRANSFORMER: the backbone name; defaults to "roberta" and must
        match the checkpoints.

    Any failure (missing package, missing configuration, load error) is
    captured as a string rather than raised. The attempt is not retried.
    """
    global _gector_predictor, _gector_error, _gector_tried
    if _gector_tried:
        return _gector_predictor, _gector_error
    _gector_tried = True
    try:
        # grammarly/gector exposes GecBERTModel (there is no GECModel class); it
        # chooses CPU/CUDA itself, so no device argument is passed.
        # NOTE(review): other constructor defaults (e.g. ensemble handling) follow
        # the upstream class — confirm against the installed gector version.
        from gector.gec_model import GecBERTModel  # requires allennlp + pretrained artifacts

        model_paths = os.environ.get("GEC_MODEL_PATHS", "").strip()
        vocab_path = os.environ.get("GEC_VOCAB_PATH", "").strip()
        if not model_paths or not vocab_path:
            raise RuntimeError(
                "GECToR selected but model artifacts are not configured. "
                "Set GEC_MODEL_PATHS (space-separated .th files) and GEC_VOCAB_PATH (vocab dir)."
            )
        _gector_predictor = GecBERTModel(
            vocab_path=vocab_path,
            model_paths=model_paths.split(),
            model_name=os.environ.get("GEC_TRANSFORMER", "roberta").strip() or "roberta",
            min_error_probability=0.0,
            confidence=0.0,
            iterations=2,
            special_tokens_fix=1,
        )
    except Exception as e:  # deliberate boundary: the error is surfaced in the UI instead
        _gector_error = str(e)
        _gector_predictor = None
        print(f"[GECToR] not loaded: {_gector_error}", flush=True)
    return _gector_predictor, _gector_error


def gector_correct(text: str):
    """Correct ``text`` with GECToR, or explain why GECToR is unavailable.

    The input is split on whitespace into tokens, and the corrected tokens are
    re-joined with single spaces.
    """
    predictor, err = try_load_gector()
    if err or predictor is None:
        return f"[GECToR not active] {err or 'Unknown error.'}\n" \
               f"Enable by setting GEC_MODEL_PATHS and GEC_VOCAB_PATH to pretrained files."
    tokens = text.strip().split()
    # handle_batch returns (corrected_batch, total_updates); take the single sentence.
    corrected_batch, _ = predictor.handle_batch([tokens])
    return " ".join(corrected_batch[0])
# ---- Llama-3.1-8B GEC (GGUF via llama-cpp-python) ----
_llama_model = None  # loaded llama_cpp.Llama, once available
_llama_err = None  # message from a failed load attempt
_llama_tried = False  # loading is attempted at most once per process
# Choose a sensible quant; adjust if you upload a different one to your Space.
LLAMA_REPO = "mradermacher/Llama-3.1-8B-Instruct-Grammatical-Error-Correction-2-GGUF"
# Llama.from_pretrained matches `filename` as a glob against the repo's file list,
# so a quant-suffix pattern does not depend on the exact file stem.
# NOTE(review): the previous default "llama-3.1-8b-instruct-gec.Q4_K_S.gguf" does
# not look like this repo's "<ModelName>.<QUANT>.gguf" naming — confirm against the repo.
LLAMA_FILE = os.environ.get("LLAMA_GGUF_FILE", "*.Q4_K_S.gguf")


def try_load_llama():
    """Load the GGUF GEC model from the Hub once and return ``(model, error)``.

    Any failure (missing llama_cpp, download error, no or several matching
    files) is captured as a string rather than raised. The attempt is not retried.
    """
    global _llama_model, _llama_err, _llama_tried
    if _llama_tried:
        return _llama_model, _llama_err
    _llama_tried = True
    try:
        from llama_cpp import Llama

        # Downloads (and caches) the matching file from the Hub.
        _llama_model = Llama.from_pretrained(
            repo_id=LLAMA_REPO,
            filename=LLAMA_FILE,
            n_ctx=2048,
            n_threads=int(os.environ.get("OMP_NUM_THREADS", "4")),
            n_batch=128,
            verbose=False,
        )
    except Exception as e:  # deliberate boundary: the error is surfaced in the UI instead
        _llama_model = None
        _llama_err = str(e)
    return _llama_model, _llama_err
def llama_gec_correct(text: str, max_new_tokens=256):
    """Correct ``text`` with the GGUF Llama-3.1 GEC model, or explain why it is unavailable.

    Args:
        text: input text; it may contain several paragraphs.
        max_new_tokens: cap on generated tokens. It is cast to int in case the UI
            slider delivers a float.

    Returns:
        The model's corrected text, stripped of surrounding whitespace.
    """
    mdl, err = try_load_llama()
    if err or mdl is None:
        return f"[Llama GGUF not active] {err or 'Unknown error.'}\n" \
               f"Check model availability or set LLAMA_GGUF_FILE to a valid filename."
    messages = [
        {
            "role": "system",
            "content": "You are a precise grammatical error corrector. "
                       "Return only the corrected text without explanations.",
        },
        {"role": "user", "content": text},
    ]
    # Chat completion applies the chat template stored in the GGUF (this is an
    # Instruct fine-tune) and ends at the model's end-of-turn token. The old raw
    # prompt stopped at "\n\n", which dropped every paragraph after the first.
    out = mdl.create_chat_completion(
        messages=messages,
        max_tokens=int(max_new_tokens),
        temperature=0.0,  # greedy decoding: corrections should be deterministic
    )
    return (out["choices"][0]["message"]["content"] or "").strip()
# -----------------------
# Router
# -----------------------
# Dropdown label -> handler(text, max_new_tokens). Insertion order is the order
# the labels appear in the UI dropdown.
_ROUTES = {
    "T5-small (prithivida)": lambda text, n: t5_correct(text, T5_SMALL, max_new_tokens=n),
    "T5-base (vennify)": lambda text, n: t5_correct(text, T5_BASE, max_new_tokens=n),
    "GECToR (tagging)": lambda text, n: gector_correct(text),
    "Llama-3.1-8B-GEC (GGUF)": lambda text, n: llama_gec_correct(text, max_new_tokens=n),
}
MODEL_OPTIONS = list(_ROUTES)


def correct_text_router(text: str, model_choice: str, max_new_tokens=128):
    """Dispatch ``text`` to the correction backend named by ``model_choice``.

    Raises:
        gr.Error: if ``text`` is empty or only whitespace.

    Returns:
        The backend's output, or "Unknown model selection." for an unrecognised label.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        raise gr.Error("Please enter text to correct.")
    for label, handler in _ROUTES.items():
        if model_choice == label:
            return handler(cleaned, max_new_tokens)
    return "Unknown model selection."
# -----------------------
# UI
# -----------------------
# Gradio front end: one tab per feature, wired to the handlers defined above.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Header; reports the device string chosen at import time.
    gr.Markdown(
        f"# 🎨 Style transfer (SD 1.5 img2img) + ✍️ English correction\n"
        f"- Device detected: **{device.upper()}**\n"
        f"- Models: T5-small, T5-base, GECToR, Llama-3.1-8B-GEC (GGUF)\n"
    )
    # Gradio lays components out in creation order, so statement order here is the page layout.
    with gr.Tab("Image style transfer"):
        with gr.Row():
            # type="pil" hands sd_style_transfer a PIL image (it calls .convert()/.resize()).
            img_in = gr.Image(label="Input image", type="pil")
            img_out = gr.Image(label="Styled output")
        prompt = gr.Textbox(label="Style prompt", placeholder="e.g., watercolor wash, halftone dots, 1960s comic shading")
        with gr.Row():
            strength = gr.Slider(0.1, 0.95, value=0.55, step=0.05, label="Style strength")
            guidance = gr.Slider(1.0, 12.0, value=5.5, step=0.5, label="Guidance")
            steps = gr.Slider(5, 40, value=18, step=1, label="Steps")
        with gr.Row():
            width = gr.Slider(256, 768, value=512, step=64, label="Width")
            height = gr.Slider(256, 768, value=512, step=64, label="Height")
            seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
        run_btn = gr.Button("Transfer style", variant="primary")
        # Inputs are passed positionally: this list must follow the parameter order
        # sd_style_transfer(input_image, prompt, strength, guidance, steps, width, height, seed).
        run_btn.click(
            fn=sd_style_transfer,
            inputs=[img_in, prompt, strength, guidance, steps, width, height, seed],
            outputs=[img_out]
        )
    with gr.Tab("English grammar correction"):
        model_choice = gr.Dropdown(MODEL_OPTIONS, value="T5-small (prithivida)", label="Model")
        txt_in = gr.Textbox(lines=6, label="Input text")
        # Forwarded as max_new_tokens; the GECToR route does not use it.
        max_new = gr.Slider(32, 512, value=128, step=16, label="Max tokens (generation models)")
        txt_out = gr.Textbox(lines=6, label="Corrected text")
        corr_btn = gr.Button("Correct", variant="primary")
        # Positional mapping to correct_text_router(text, model_choice, max_new_tokens).
        corr_btn.click(
            fn=correct_text_router,
            inputs=[txt_in, model_choice, max_new],
            outputs=[txt_out]
        )
    # Usage hints shown below both tabs.
    gr.Markdown(
        "Tips:\n"
        "- On CPU: steps 12–20, guidance 4–7, 512×512 for SD speed.\n"
        "- T5-small = fastest, T5-base = more accurate.\n"
        "- GECToR needs AllenNLP and pretrained tagger files (set GEC_MODEL_PATHS & GEC_VOCAB_PATH).\n"
        "- Llama GGUF loads from Hub (Q4_K_S by default). Adjust LLAMA_GGUF_FILE if needed."
    )

# Start the server only when run as a script.
if __name__ == "__main__":
    demo.launch()