NewUserID commited on
Commit
d9e73df
·
verified ·
1 Parent(s): edd03a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -17
app.py CHANGED
@@ -20,8 +20,11 @@ device="cpu"
20
  #prompt_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=YOUR_TOKEN)
21
  #prompt_pipe.to(device)
22
 
23
- img_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=YOUR_TOKEN)
24
- img_pipe.safety_checker = lambda images, clip_input: (images, False)
 
 
 
25
  img_pipe.to(device)
26
 
27
  source_img = gr.Image(source="upload", type="filepath", label="init_img | 512*512 px")
@@ -38,32 +41,21 @@ def resize(value,img):
38
 
39
 
40
  def infer(source_img, prompt, guide, steps, seed, strength):
41
- generator = torch.Generator('cpu').manual_seed(seed)
42
 
43
- # Load and resize image (Gradio gives file path)
44
  source_image = Image.open(source_img).convert("RGB")
45
  source_image = source_image.resize((512, 512), Image.Resampling.LANCZOS)
46
 
47
- # Run the img2img pipeline
48
  images_list = img_pipe(
49
- [prompt], # prompt as list
50
- image=source_image, # fixed: image, not init_image
51
  strength=strength,
52
  guidance_scale=guide,
53
  num_inference_steps=steps,
54
  generator=generator
55
  )
56
 
57
- # Handle output (NSFW filter placeholder logic)
58
- images = []
59
- safe_image = Image.open("unsafe.png")
60
- for i, image in enumerate(images_list["images"]):
61
- if images_list["nsfw_content_detected"][i]:
62
- images.append(image) # replace with safe_image if NSFW
63
- else:
64
- images.append(image)
65
-
66
- return images
67
 
68
  print("Great sylvain ! Everything is working fine !")
69
 
 
20
  #prompt_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=YOUR_TOKEN)
21
  #prompt_pipe.to(device)
22
 
23
+ img_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
24
+ "runwayml/stable-diffusion-v1-5",
25
+ use_auth_token=YOUR_TOKEN,
26
+ safety_checker=None, # ← disable safety checker
27
+ )
28
  img_pipe.to(device)
29
 
30
  source_img = gr.Image(source="upload", type="filepath", label="init_img | 512*512 px")
 
41
 
42
 
43
  def infer(source_img, prompt, guide, steps, seed, strength):
44
+ generator = torch.Generator("cpu").manual_seed(seed)
45
 
 
46
  source_image = Image.open(source_img).convert("RGB")
47
  source_image = source_image.resize((512, 512), Image.Resampling.LANCZOS)
48
 
 
49
  images_list = img_pipe(
50
+ [prompt],
51
+ image=source_image,
52
  strength=strength,
53
  guidance_scale=guide,
54
  num_inference_steps=steps,
55
  generator=generator
56
  )
57
 
58
+ return images_list["images"]
 
 
 
 
 
 
 
 
 
59
 
60
  print("Great sylvain ! Everything is working fine !")
61