# Hugging Face Space demo: image style transfer (Stable Diffusion img2img)
# plus grammar correction (T5 seq2seq).
# (Page-scrape artifacts — Space status, file size, commit hashes, and the
# fused line-number gutter — removed so the file parses as Python.)
import gradio as gr
from diffusers import StableDiffusionImg2ImgPipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# --- Device selection: CUDA if available, else CPU ---
device = "cuda" if torch.cuda.is_available() else "cpu"
# --- Load Stable Diffusion for style transfer ---
# NOTE(review): loads weights in the default fp32; on a CUDA device, passing
# torch_dtype=torch.float16 would roughly halve VRAM use — confirm target
# hardware before changing.
sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5"
)
sd_pipe = sd_pipe.to(device)
sd_pipe.enable_attention_slicing()  # helps reduce memory usage
# --- Load T5-base or T5-small for grammar correction ---
model_name = "vennify/t5-base-grammar-correction"  # or "prithivida/grammar_error_correcter_v1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
def style_transfer(input_image, prompt, strength=0.5, guidance=7.5):
    """Apply a text-prompted style to *input_image* via SD img2img.

    Args:
        input_image: PIL image to restyle.
        prompt: Text description of the target style.
        strength: How far from the source image to drift (0-1).
        guidance: Classifier-free guidance scale.

    Returns:
        The first generated PIL image from the pipeline.
    """
    # Gather pipeline arguments in one place, then run a single pass.
    pipe_kwargs = {
        "prompt": prompt,
        "image": input_image,
        "strength": strength,
        "guidance_scale": guidance,
    }
    generated = sd_pipe(**pipe_kwargs).images
    return generated[0]
def correct_text(text):
    """Return a grammar-corrected version of *text* using the T5 model.

    Args:
        text: Raw input sentence(s) to correct.

    Returns:
        The decoded corrected string (special tokens stripped).
    """
    # BUG FIX: the vennify/t5-base-grammar-correction model is trained with
    # the task prefix "grammar: " — the previous "gec: " prefix caused the
    # model to frequently echo the input unchanged.
    # truncation=True guards against inputs longer than the encoder limit.
    inputs = tokenizer(
        "grammar: " + text, return_tensors="pt", truncation=True
    ).to(device)
    outputs = model.generate(**inputs, max_length=128)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# --- Gradio UI ---
# Two independent tabs: one wiring the img2img pipeline, one the T5 corrector.
with gr.Blocks() as demo:
    gr.Markdown("## 🎨 Style Transfer + ✍️ Grammar Correction")

    # Tab 1: image in + style prompt -> stylized image out.
    with gr.Tab("Image Style Transfer"):
        source_image = gr.Image(type="pil")
        style_prompt = gr.Textbox(label="Style Prompt")
        result_image = gr.Image()
        transfer_btn = gr.Button("Transfer Style")
        transfer_btn.click(
            style_transfer,
            inputs=[source_image, style_prompt],
            outputs=result_image,
        )

    # Tab 2: raw text in -> corrected text out.
    with gr.Tab("Text Correction"):
        raw_text = gr.Textbox(label="Enter text")
        fixed_text = gr.Textbox(label="Corrected text")
        correct_btn = gr.Button("Correct")
        correct_btn.click(correct_text, inputs=raw_text, outputs=fixed_text)

demo.launch()