Geek7 commited on
Commit
9cc033b
·
verified ·
1 Parent(s): 052862e

Update edit_app.py

Browse files
Files changed (1) hide show
  1. edit_app.py +95 -0
edit_app.py CHANGED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ import random
5
+ import torch
6
+ from PIL import Image, ImageOps
7
+ from diffusers import StableDiffusionInstructPix2PixPipeline
8
+ import streamlit as st
9
+
10
+ # Help text to be displayed in the app
11
+ help_text = """
12
+ If you're not getting what you want, there may be a few reasons:
13
+ 1. Is the image not changing enough? Your Image CFG weight may be too high. This value dictates how similar the output should be to the input.
14
+ 2. Conversely, is the image changing too much, such that the details in the original image aren't preserved? Try:
15
+ * Increasing the Image CFG weight, or
16
+ * Decreasing the Text CFG weight
17
+ 3. Try generating results with different random seeds by setting "Randomize Seed".
18
+ """
19
+
20
+ # Example instructions for users to test
21
+ example_instructions = [
22
+ "Make it a picasso painting",
23
+ "Turn it into an anime.",
24
+ "add dramatic lighting",
25
+ "Convert to black and white",
26
+ ]
27
+
28
+ # Load the model from Hugging Face
29
+ model_id = "timbrooks/instruct-pix2pix"
30
+ pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker=None).to("cuda")
31
+
32
+ # Main Streamlit App
33
+ def main():
34
+ st.title("InstructPix2Pix Image Editing")
35
+
36
+ st.markdown(help_text)
37
+
38
+ # Upload input image
39
+ uploaded_image = st.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])
40
+
41
+ if uploaded_image is not None:
42
+ input_image = Image.open(uploaded_image).convert("RGB")
43
+ st.image(input_image, caption="Uploaded Image", width=512)
44
+ else:
45
+ st.warning("Please upload an image to proceed.")
46
+ return
47
+
48
+ # Choose or type in instruction for image edit
49
+ instruction = st.selectbox("Choose an instruction or type your own", example_instructions)
50
+ custom_instruction = st.text_input("Or type your custom instruction", "")
51
+ if custom_instruction:
52
+ instruction = custom_instruction
53
+
54
+ # Control parameters for generation
55
+ steps = st.slider("Steps", min_value=20, max_value=100, value=50, step=1)
56
+ randomize_seed = st.checkbox("Randomize Seed", value=True)
57
+ seed = st.number_input("Seed (Only used if Randomize Seed is disabled)", min_value=0, value=random.randint(0, 10000))
58
+
59
+ text_cfg_scale = st.slider("Text CFG", min_value=1.0, max_value=10.0, value=7.5, step=0.1)
60
+ image_cfg_scale = st.slider("Image CFG", min_value=0.5, max_value=2.0, value=1.5, step=0.1)
61
+
62
+ # Process button
63
+ if st.button("Generate Edited Image"):
64
+ with st.spinner("Generating the edited image..."):
65
+ result_image = generate(input_image, instruction, steps, randomize_seed, seed, text_cfg_scale, image_cfg_scale)
66
+ st.image(result_image, caption="Edited Image", width=512)
67
+
68
+ # Download the edited image
69
+ st.download_button("Download Image", data=result_image.tobytes(), file_name="edited_image.png", mime="image/png")
70
+
71
+ # Generate the edited image
72
+ def generate(input_image: Image.Image, instruction: str, steps: int, randomize_seed: bool, seed: int, text_cfg_scale: float, image_cfg_scale: float):
73
+ # Handle seed
74
+ if randomize_seed:
75
+ seed = random.randint(0, 100000)
76
+
77
+ # Resize the input image to 512x512 (Stable Diffusion requires square images)
78
+ width, height = input_image.size
79
+ factor = 512 / max(width, height)
80
+ width = int((width * factor) // 64) * 64
81
+ height = int((height * factor) // 64) * 64
82
+ input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
83
+
84
+ # Generate the edited image using the Pix2Pix pipeline
85
+ generator = torch.manual_seed(seed)
86
+ edited_image = pipe(
87
+ instruction, image=input_image,
88
+ guidance_scale=text_cfg_scale, image_guidance_scale=image_cfg_scale,
89
+ num_inference_steps=steps, generator=generator,
90
+ ).images[0]
91
+
92
+ return edited_image
93
+
94
+ if __name__ == "__main__":
95
+ main()