import gradio as gr
from src.rep_api import sam_segment, replicate_zest
from src.utils import open_image_from_url, image_to_base64, cut_alpha_from_image
from src.utils import convert_to_pil


def sam_zest(image, prompt, negative_prompt, material_img):
    """Apply a ZeST material transfer only inside a SAM-segmented region.

    The region is segmented with ``sam_segment`` from the text prompts. The
    masked cut-out is sent to ``replicate_zest`` together with the material
    image, and the returned image is composited back onto a copy of the base
    image through the segmentation mask.

    Args:
        image: Base image from the Gradio input. It may be a NumPy array
            (the ``gr.Image`` default) or a PIL image. It is converted with
            ``convert_to_pil`` on a copy and is not modified in place.
        prompt: Text describing the area to segment.
        negative_prompt: Text describing the area to exclude from the mask.
        material_img: Material/reference image forwarded to ``replicate_zest``.

    Returns:
        A new PIL image: the base image with the ZeST output pasted into the
        masked area.
    """
    outputs = sam_segment(image, prompt, negative_prompt)
    inverted = open_image_from_url(outputs["inverted_mask"])
    # The service hands back an inverted mask; flip it so that (presumably)
    # the segmented area becomes 255, i.e. opaque for putalpha and selected
    # for paste.
    mask = inverted.convert("L").point(lambda x: 255 - x)

    # Work on a PIL copy: the raw input may be a NumPy array (no .paste),
    # and the caller's image should not be mutated.
    base = convert_to_pil(image.copy())
    # putalpha/paste require the mask to match the image size exactly.
    if mask.size != base.size:
        mask = mask.resize(base.size)

    # Build the RGBA cut-out on a separate copy so `base` keeps its original
    # mode (no alpha channel) and can serve as the compositing target.
    cutout = base.copy()
    cutout.putalpha(mask)

    output = replicate_zest(cut_alpha_from_image(cutout), material_img=material_img)
    zest_img = open_image_from_url(output).resize(base.size)

    # The explicit mask limits the paste to the segmented area. It also
    # overrides any alpha channel zest_img may carry.
    base.paste(zest_img, (0, 0), mask)
    return base


def sam_zest_tab():
    """Build the "Zest Segment" tab and wire its button to ``sam_zest``.

    Must be called inside an enclosing Gradio ``Tabs``/``Blocks`` context,
    because it creates components via context managers.

    Returns:
        Tuple ``(image, output_img)``: the base-image input component and the
        output image component, so callers can connect them elsewhere.
    """
    with gr.TabItem("Zest Segment", id="zest_seg"):
        with gr.Row():
            with gr.Column():
                image = gr.Image(label="Base Image")
                seg_prompt = gr.Textbox(label="Segment area")
                seg_negative = gr.Textbox(label="Dont seg area")
                zest_image = gr.Image(label="Zest image")
                # Registered for its UI side effect; the object itself is unused.
                gr.Examples(
                    examples=[
                        "https://replicate.delivery/pbxt/Kl23gJODaW7EuxrDzBG9dcgqRdMaYSWmBQ9UexnwPiL7AnIr/3.jpg",
                        "https://replicate.delivery/pbxt/Kl2WefehduxwWcQc5OrrBH6AkojQ6OqyQSKBvBLrroSpEBim/f2f0488a-180e-4d7e-9907-f26f92ac5f16.jpg",
                        "https://replicate.delivery/pbxt/Kl2VlUibviSP8Kq5ULLJmMOWorog1YFu0zTreqhqX97c62ku/572a1fc9-a114-4d5b-8c7c-85aa5648c7b4.jpg",
                        "https://replicate.delivery/pbxt/Kl2VCw1UVIJsYw9r8iqSYUMm65ePJhfYOLNolOE8CwxfRjX2/28481ff0-0829-42af-a658-fb96be2abb3d.jpg",
                        "Test_images/pattern_1.png",
                        "Test_images/pattern_2.jpg",
                        "Test_images/pattern_3.jpg",
                        "Test_images/pattern_4.jpg",
                    ],
                    inputs=[zest_image],
                )
            with gr.Column():
                gen_zest = gr.Button("Add ZEST")
                output_img = gr.Image(label="output")
        gen_zest.click(
            sam_zest,
            inputs=[image, seg_prompt, seg_negative, zest_image],
            outputs=[output_img],
        )
    return image, output_img