File size: 3,872 Bytes
c77a06e
3e967a9
 
 
 
 
 
5dc62a4
3e967a9
 
 
e8773e4
3e967a9
 
 
 
 
 
 
 
dd427e4
5dc62a4
 
c77a06e
 
 
 
 
 
 
 
 
 
dd427e4
 
5dc62a4
dd427e4
 
 
3e967a9
589867c
b3c2131
dd427e4
 
 
 
 
 
 
b3c2131
c1a4711
b3c2131
dd427e4
c1a4711
b3c2131
c1a4711
 
 
b3c2131
 
 
 
 
 
 
 
c1a4711
 
b3c2131
dd427e4
b3c2131
 
 
 
 
 
 
 
 
 
 
 
 
c1a4711
 
 
3e967a9
 
c1a4711
3e967a9
 
 
b3c2131
5157e90
589867c
5157e90
589867c
5157e90
3e967a9
c65d468
3e967a9
b3c2131
3e967a9
 
c65d468
c1a4711
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import json  # used to load the custom profanity word list from a JSON file
import torch
import spaces
from PIL import Image
from generationPipeline import generate
from transformers import CLIPTokenizer
from loadModel import preload_models_from_standard_weights
from better_profanity import profanity  # Import the profanity-check library
import gradio as gr


# Prefer the GPU when one is available; otherwise fall back to CPU.
Device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(f"Using device: {Device}")


# CLIP tokenizer files (vocab.json / merges.txt) and the checkpoint are
# expected to sit in the current working directory.
tokenizer = CLIPTokenizer("vocab.json", merges_file="merges.txt")
model_file = "weights-inkpen.ckpt"
models = preload_models_from_standard_weights(model_file, Device)

# Load better_profanity's built-in censor-word list; custom words from
# profane.json are added further down in this file.
profanity.load_censor_words()

# Function to load custom profanity words from a JSON file
def load_custom_words(file_path):
    """Load the custom profanity word list from a JSON file.

    The file must contain a top-level JSON object with a ``"custom_words"``
    key mapping to a list of strings.

    Parameters
    ----------
    file_path : str
        Path to the JSON file.

    Returns
    -------
    list
        The value stored under ``"custom_words"``.

    Raises
    ------
    FileNotFoundError
        If *file_path* does not exist.
    json.JSONDecodeError
        If the file is not valid JSON.
    KeyError
        If the ``"custom_words"`` key is missing.
    """
    # Explicit UTF-8 so non-ASCII words parse identically on every platform
    # (the default open() encoding is locale-dependent, e.g. cp1252 on Windows).
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return data["custom_words"]

# Load the custom words from the JSON file and add them to the censor list.
custom_words = load_custom_words("profane.json")
profanity.add_censor_words(custom_words)


def filter_prompt(prompt):
    """Return *prompt* unchanged when it is clean, otherwise a rejection message.

    Callers detect rejection by comparing the result against the exact
    rejection string, so that string must not change.
    """
    is_inappropriate = profanity.contains_profanity(prompt)
    return (
        "Inappropriate content detected. Please modify the input."
        if is_inappropriate
        else prompt
    )


@spaces.GPU(duration=180)
def generate_image(mode, prompt, strength, seed, n_inference_steps, input_image=None):
    """Generate an image from a text prompt, optionally conditioned on an image.

    Parameters
    ----------
    mode : str
        Either "Text-to-Image" or "Image-to-Image".
    prompt : str
        Text prompt; rejected up front if it contains profanity.
    strength : float
        Diffusion strength in [0, 1]; for Image-to-Image, higher values
        move the output further from the input image.
    seed : int
        RNG seed for reproducibility.
    n_inference_steps : int
        Number of DDPM sampling steps.
    input_image : optional
        Source image (file path from the Gradio component); required for
        Image-to-Image mode, ignored for Text-to-Image.

    Returns
    -------
    PIL.Image.Image on success, or a str error message when the prompt is
    rejected or a required input image is missing.
    """
    # Reject inappropriate prompts before doing any GPU work.  The caller
    # (Gradio) receives the rejection message as-is.
    filtered_prompt = filter_prompt(prompt)
    if filtered_prompt == "Inappropriate content detected. Please modify the input.":
        return filtered_prompt

    # Resolve the conditioning image: Text-to-Image ignores any upload,
    # Image-to-Image requires one, anything else is invalid input.
    if mode == "Text-to-Image":
        source_image = None
    elif mode == "Image-to-Image" and input_image is not None:
        source_image = input_image
    else:
        return "Please upload an image for Image-to-Image mode."

    # A single generate() call covers both modes — the two original branches
    # were identical except for the input_image argument.
    output_image = generate(
        prompt=filtered_prompt,
        uncond_prompt="",
        input_image=source_image,
        strength=strength,
        do_cfg=True,
        cfg_scale=8,
        sampler_name="ddpm",
        n_inference_steps=n_inference_steps,
        seed=seed,
        models=models,
        device=Device,
        idle_device="cpu",
        tokenizer=tokenizer,
    )

    return Image.fromarray(output_image)

# Gradio UI: mode toggle, prompt, generation controls, and an optional
# source image (used only in Image-to-Image mode).
_mode_toggle = gr.Radio(choices=["Text-to-Image", "Image-to-Image"], label="Mode")
_prompt_box = gr.Textbox(label="Text Prompt")
_strength_slider = gr.Slider(0, 1, step=0.01, label="Strength : (Note: set strength between 0.01 to 0.9, For Image-2-Image: Strength ~ 1 means that the output will be further from the input image. Strength ~ 0 means that the output will be closer to the input image.)")
_seed_input = gr.Number(label="Seed (for reproducibility)")
_steps_slider = gr.Slider(1, 200, step=10, label="Number of Inference Steps")
_image_upload = gr.Image(type='filepath', label='Input Image Only for Image-to-Image')

iface = gr.Interface(
    fn=generate_image,
    inputs=[_mode_toggle, _prompt_box, _strength_slider, _seed_input, _steps_slider, _image_upload],
    outputs=gr.Image(label="Generated Image"),
    title="Stable Diffusion Image Generator",
    description="Generate images from text prompts or use image-to-image generation.",
)

# Launch the app with a public share link and debug logging enabled.
iface.launch(share=True, debug=True)