tinyInstruct / app.py
AItool's picture
Update app.py
abfe8c9 verified
raw
history blame
1.88 kB
import gradio as gr
from diffusers import StableDiffusionImg2ImgPipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# --- Device selection: CUDA if available, else CPU ---
# All models below are moved to this single device so image and text
# pipelines never mix devices.
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Load Stable Diffusion for style transfer ---
# Downloads (or reads from the HF cache) the full SD v1.5 img2img pipeline;
# this is a multi-GB fetch on first run.
sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5"
)
sd_pipe = sd_pipe.to(device)
sd_pipe.enable_attention_slicing()  # helps reduce memory usage

# --- Load T5-base or T5-small for grammar correction ---
# NOTE(review): this model's card says inputs should use the "grammar: "
# task prefix; the alternative Gramformer-style model uses "gec: ".
model_name = "vennify/t5-base-grammar-correction"  # or "prithivida/grammar_error_correcter_v1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
def style_transfer(input_image, prompt, strength=0.5, guidance=7.5):
    """Apply a text-described style to *input_image* via img2img diffusion.

    Args:
        input_image: Source PIL image to restyle.
        prompt: Text prompt describing the target style.
        strength: How far to move away from the source image (0-1).
        guidance: Classifier-free guidance scale for prompt adherence.

    Returns:
        The first generated PIL image from the pipeline output.
    """
    generated = sd_pipe(
        prompt=prompt,
        image=input_image,
        strength=strength,
        guidance_scale=guidance,
    )
    # The pipeline returns a batch; this app only ever uses one image.
    return generated.images[0]
def correct_text(text):
    """Return *text* with grammar corrected by the seq2seq model.

    Args:
        text: Raw input sentence(s) to correct.

    Returns:
        The decoded model output with special tokens stripped.
    """
    # BUG FIX: vennify/t5-base-grammar-correction was trained with the
    # "grammar: " task prefix (per its model card). "gec: " is the prefix
    # for the Gramformer-style prithivida model and degrades output here.
    inputs = tokenizer(
        "grammar: " + text,
        return_tensors="pt",
        truncation=True,   # guard against inputs longer than the model limit
        max_length=128,
    ).to(device)
    outputs = model.generate(**inputs, max_length=128)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# --- Gradio UI ---
# Two independent tabs: one wired to the diffusion pipeline, one to the
# grammar model. Component creation order determines on-screen layout.
with gr.Blocks() as demo:
    gr.Markdown("## 🎨 Style Transfer + ✍️ Grammar Correction")

    # Tab 1: image-to-image style transfer.
    with gr.Tab("Image Style Transfer"):
        image_input = gr.Image(type="pil")
        style_prompt = gr.Textbox(label="Style Prompt")
        image_output = gr.Image()
        transfer_button = gr.Button("Transfer Style")
        transfer_button.click(
            style_transfer,
            inputs=[image_input, style_prompt],
            outputs=image_output,
        )

    # Tab 2: grammar correction.
    with gr.Tab("Text Correction"):
        text_input = gr.Textbox(label="Enter text")
        text_output = gr.Textbox(label="Corrected text")
        correct_button = gr.Button("Correct")
        correct_button.click(correct_text, inputs=text_input, outputs=text_output)

demo.launch()