Jasmeet Singh commited on
Commit
dd427e4
·
verified ·
1 Parent(s): 2d4ad80

profanity-check

Browse files

added profanity check to filter inappropriate prompts inputs.

Files changed (1) hide show
  1. app.py +17 -2
app.py CHANGED
@@ -4,6 +4,7 @@ from PIL import Image
4
  from generationPipeline import generate
5
  from transformers import CLIPTokenizer
6
  from loadModel import preload_models_from_standard_weights
 
7
  import gradio as gr
8
 
9
 
@@ -16,13 +17,27 @@ tokenizer = CLIPTokenizer("vocab.json", merges_file="merges.txt")
16
  model_file = "weights-inkpen.ckpt"
17
  models = preload_models_from_standard_weights(model_file, Device)
18
 
 
 
 
 
 
 
 
19
 
20
  @spaces.GPU(duration=180)
21
  def generate_image(mode, prompt, strength, seed, n_inference_steps, input_image=None):
 
 
 
 
 
 
 
22
  if mode == "Text-to-Image":
23
  # Ignore the input image
24
  output_image = generate(
25
- prompt=prompt,
26
  uncond_prompt="",
27
  input_image=None,
28
  strength=strength,
@@ -39,7 +54,7 @@ def generate_image(mode, prompt, strength, seed, n_inference_steps, input_image=
39
  elif mode == "Image-to-Image" and input_image is not None:
40
  # Use the uploaded image
41
  output_image = generate(
42
- prompt=prompt,
43
  uncond_prompt="",
44
  input_image=input_image,
45
  strength=strength,
 
4
  from generationPipeline import generate
5
  from transformers import CLIPTokenizer
6
  from loadModel import preload_models_from_standard_weights
7
+ from profanity_check import predict # Import the profanity-check library
8
  import gradio as gr
9
 
10
 
 
17
  model_file = "weights-inkpen.ckpt"
18
  models = preload_models_from_standard_weights(model_file, Device)
19
 
20
+ ## profanity check on input prompt
21
+
22
+ def filter_prompt(prompt):
23
+ if predict([prompt])[0] == 1:
24
+ return "Inappropriate content detected. Please modify the input."
25
+ return prompt
26
+
27
 
28
  @spaces.GPU(duration=180)
29
  def generate_image(mode, prompt, strength, seed, n_inference_steps, input_image=None):
30
+
31
+ ##check prompt id there is anything inappropriate
32
+
33
+ filtered_prompt = filter_prompt(prompt)
34
+ if filtered_prompt == "Inappropriate content detected. Please modify the input.":
35
+ return filtered_prompt
36
+
37
  if mode == "Text-to-Image":
38
  # Ignore the input image
39
  output_image = generate(
40
+ prompt=filtered_prompt,
41
  uncond_prompt="",
42
  input_image=None,
43
  strength=strength,
 
54
  elif mode == "Image-to-Image" and input_image is not None:
55
  # Use the uploaded image
56
  output_image = generate(
57
+ prompt=filtered_prompt,
58
  uncond_prompt="",
59
  input_image=input_image,
60
  strength=strength,