File size: 1,883 Bytes
9639dd1
abfe8c9
 
 
 
 
 
 
 
 
 
18a628a
abfe8c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9639dd1
18a628a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import gradio as gr
from diffusers import StableDiffusionImg2ImgPipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# --- Device selection: CUDA if available, else CPU ---
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision roughly halves GPU memory for the diffusion model; CPU kernels
# generally lack fp16 support, so stay in float32 there.
_sd_dtype = torch.float16 if device == "cuda" else torch.float32

# --- Load Stable Diffusion for style transfer ---
# NOTE: the original "runwayml/stable-diffusion-v1-5" repo was taken down from
# the Hugging Face Hub; "stable-diffusion-v1-5/stable-diffusion-v1-5" is the
# maintained mirror of the same v1.5 weights.
sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=_sd_dtype,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.enable_attention_slicing()  # trades a little speed for lower peak memory

# --- Load a T5 grammar-correction model (kept in float32) ---
# Each checkpoint was trained with its own task prefix, and the text passed to
# the model must start with the matching one:
#   "vennify/t5-base-grammar-correction"     -> "grammar: "
#   "prithivida/grammar_error_correcter_v1"  -> "gec: "
model_name = "vennify/t5-base-grammar-correction"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

def style_transfer(input_image, prompt, strength=0.5, guidance=7.5):
    """Re-render *input_image* in the style described by *prompt*.

    Runs the module-level Stable Diffusion img2img pipeline ``sd_pipe``.

    Args:
        input_image: Image to restyle; the Gradio input supplies a PIL image
            (``type="pil"``). ``None`` means nothing was uploaded.
        prompt: Text describing the target style.
        strength: Passed to the pipeline as ``strength`` (how far the result
            may drift from the input; diffusers expects a value in [0, 1]).
        guidance: Passed to the pipeline as ``guidance_scale``.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If no image was provided; Gradio shows the message to the user.
    """
    if input_image is None:
        # Without this the pipeline fails deep inside preprocessing with an
        # error that means nothing to the person clicking the button.
        raise gr.Error("Please upload an image first.")

    # The VAE expects 3-channel input; RGBA or grayscale PIL images would
    # otherwise fail inside the pipeline. Non-PIL inputs pass through as-is.
    if getattr(input_image, "mode", "RGB") != "RGB":
        input_image = input_image.convert("RGB")

    result = sd_pipe(
        prompt=prompt,
        image=input_image,
        strength=strength,
        guidance_scale=guidance,
    ).images[0]
    return result

def correct_text(text, *, prefix="grammar: ", max_length=128):
    """Return a grammar-corrected version of *text* using the loaded T5 model.

    Args:
        text: Text to correct. ``None`` or whitespace-only input returns ``""``
            without running the model.
        prefix: Task prefix prepended to the input. It must match the prefix
            the loaded checkpoint was trained with: ``"grammar: "`` for
            vennify/t5-base-grammar-correction (the default model), ``"gec: "``
            for prithivida/grammar_error_correcter_v1.
        max_length: Upper bound on the generated sequence length in tokens,
            forwarded to ``model.generate``.

    Returns:
        The decoded model output with special tokens removed.
    """
    if text is None or not text.strip():
        return ""
    # Truncate over-long input to the tokenizer's model_max_length instead of
    # feeding the model an unbounded sequence.
    inputs = tokenizer(
        prefix + text, return_tensors="pt", truncation=True
    ).to(device)
    outputs = model.generate(**inputs, max_length=max_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("## 🎨 Style Transfer + ✍️ Grammar Correction")

    with gr.Tab("Image Style Transfer"):
        img_in = gr.Image(type="pil")
        prompt = gr.Textbox(label="Style Prompt")
        # Expose the knobs style_transfer already accepts. The defaults equal
        # the function's own defaults, so leaving the sliders untouched gives
        # the same result as before. Strength starts above 0 because
        # img2img with strength 0 leaves no denoising steps to run.
        strength_slider = gr.Slider(
            minimum=0.1, maximum=1.0, value=0.5, step=0.05, label="Strength"
        )
        guidance_slider = gr.Slider(
            minimum=1.0, maximum=20.0, value=7.5, step=0.5, label="Guidance scale"
        )
        img_out = gr.Image()
        btn1 = gr.Button("Transfer Style")
        btn1.click(
            style_transfer,
            inputs=[img_in, prompt, strength_slider, guidance_slider],
            outputs=img_out,
        )

    with gr.Tab("Text Correction"):
        txt_in = gr.Textbox(label="Enter text")
        txt_out = gr.Textbox(label="Corrected text")
        btn2 = gr.Button("Correct")
        btn2.click(correct_text, inputs=txt_in, outputs=txt_out)

# Start the server only when run as a script, so importing this module
# (e.g. from tests or another app) does not launch it.
if __name__ == "__main__":
    demo.launch()