"""Gradio demo: Stable Diffusion image style transfer + T5 grammar correction."""

import gradio as gr
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# --- Device selection: CUDA if available, else CPU ---
device = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 halves VRAM on GPU; CPU inference requires fp32.
dtype = torch.float16 if device == "cuda" else torch.float32

# --- Load Stable Diffusion for style transfer ---
sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=dtype
)
sd_pipe = sd_pipe.to(device)
sd_pipe.enable_attention_slicing()  # helps reduce memory usage

# --- Load T5-base or T5-small for grammar correction ---
model_name = "vennify/t5-base-grammar-correction"  # or "prithivida/grammar_error_correcter_v1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)


def style_transfer(input_image, prompt, strength=0.5, guidance=7.5):
    """Restyle ``input_image`` according to ``prompt`` via SD img2img.

    Args:
        input_image: PIL image from the Gradio component (may be None if
            the user clicked the button without uploading).
        prompt: text describing the target style.
        strength: 0..1, how far to depart from the source image.
        guidance: classifier-free guidance scale.

    Returns:
        The first generated PIL image.

    Raises:
        gr.Error: when no image was provided, so the UI shows a friendly
            message instead of a server traceback.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")
    result = sd_pipe(
        prompt=prompt,
        image=input_image,
        strength=strength,
        guidance_scale=guidance,
    ).images[0]
    return result


def correct_text(text):
    """Return ``text`` with grammar corrected by the seq2seq model.

    Empty/whitespace-only input short-circuits to "" instead of running
    the model. Inputs are truncated to the model's context window so very
    long text cannot overflow T5's position limit.
    """
    if not text or not text.strip():
        return ""
    # The checkpoint was trained with a "gec: " task prefix.
    inputs = tokenizer(
        "gec: " + text, return_tensors="pt", truncation=True, max_length=512
    ).to(device)
    outputs = model.generate(**inputs, max_length=128)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("## 🎨 Style Transfer + ✍️ Grammar Correction")

    with gr.Tab("Image Style Transfer"):
        img_in = gr.Image(type="pil")
        prompt = gr.Textbox(label="Style Prompt")
        img_out = gr.Image()
        btn1 = gr.Button("Transfer Style")
        btn1.click(style_transfer, inputs=[img_in, prompt], outputs=img_out)

    with gr.Tab("Text Correction"):
        txt_in = gr.Textbox(label="Enter text")
        txt_out = gr.Textbox(label="Corrected text")
        btn2 = gr.Button("Correct")
        btn2.click(correct_text, inputs=txt_in, outputs=txt_out)

# Launch only when run as a script, so the module can be imported
# (e.g. by tests or a hosting wrapper) without starting a server.
if __name__ == "__main__":
    demo.launch()