Spaces:
Runtime error
Runtime error
| from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation | |
| from diffusers import StableDiffusionInpaintPipeline,StableDiffusionPipeline | |
| from PIL import Image | |
| import requests | |
| import cv2 | |
| import torch | |
| import matplotlib.pyplot as plt | |
| import io | |
| import requests | |
| from huggingface_hub import login | |
| import os | |
| import streamlit as st | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
# --- Model loading (executed once, at import time) --------------------------

# CLIPSeg processor/model pair; create_mask() uses them to segment the object
# named by a text prompt.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Half precision is only requested when a GPU is actually present. The
# original always asked for float16, which conflicts with the CPU fallback
# above (many half-precision ops are not implemented on CPU).
_sd_dtype = torch.float16 if device.type == 'cuda' else torch.float32

# NOTE(review): the token is read through streamlit's `st.secrets` although
# the UI further down is built with gradio -- confirm that this lookup works
# in the deployment environment.
IPmodel_path = "runwayml/stable-diffusion-inpainting"
IPpipe = StableDiffusionInpaintPipeline.from_pretrained(
    IPmodel_path,
    revision="fp16",
    torch_dtype=_sd_dtype,
    use_auth_token=st.secrets["AUTH_TOKEN"],
).to(device)

# NLLB tokenizer/model pair; translate_sentence() wraps them in a
# translation pipeline.
trans_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
trans_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

# Plain text-to-image pipeline, used when no guided inpainting is requested.
SDpipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=_sd_dtype,
    use_auth_token=st.secrets["AUTH_TOKEN"],
).to(device)
def create_mask(image, prompt):
    """Build a black/white mask for the region of *image* matching *prompt*.

    The module-level CLIPSeg ``processor``/``model`` score the image against
    the text prompt. The sigmoid of the logits is written to ``mask.png``
    with ``plt.imsave`` (no ``cmap`` is passed, so matplotlib's default
    colormap applies), read back with OpenCV, converted to grayscale,
    thresholded at 100 and inverted.

    Args:
        image: image handed to the CLIPSeg processor.
        prompt: text describing the object to segment.

    Returns:
        The PIL image opened from ``mask.png``. Pixels that passed the
        threshold are 0, all others 255 (because of the final inversion).

    Side effects:
        Overwrites ``mask.png`` in the current working directory.
    """
    inputs = processor(text=[prompt], images=[image], padding="max_length", return_tensors="pt")

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)
    preds = outputs.logits

    filename = "mask.png"
    plt.imsave(filename, torch.sigmoid(preds))

    gray_image = cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2GRAY)
    _, bw_image = cv2.threshold(gray_image, 100, 255, cv2.THRESH_BINARY)

    # The original called cv2.cvtColor(bw_image, cv2.COLOR_BGR2RGB) here and
    # threw the result away. bw_image is single-channel, so that call could
    # never affect the mask and may raise on such input; it is dropped.

    # Invert: thresholded (matching) pixels become black, the rest white.
    # NOTE(review): presumably so that inpainting repaints everything except
    # the detected object -- confirm against the pipeline's mask convention.
    mask = cv2.bitwise_not(bw_image)
    cv2.imwrite(filename, mask)
    return Image.open(filename)
def generate_image(image, product_name, target_name):
    """Inpaint *image* around the object called *product_name*.

    A mask for ``product_name`` is obtained from ``create_mask``; both the
    mask and the image are resized to 512x512 and passed to the module-level
    inpainting pipeline ``IPpipe`` with ``target_name`` as the prompt.

    Args:
        image: PIL image to edit.
        product_name: text identifying the object to mask.
        target_name: prompt describing the content to generate.

    Returns:
        The ``images`` attribute of the pipeline result.
    """
    target_size = (512, 512)

    # The mask is computed from the image as received, before it is resized.
    product_mask = create_mask(image, product_name)
    resized_image = image.resize(target_size)
    resized_mask = product_mask.resize(target_size)

    # Fixed seed: change it to get different results.
    seeded_generator = torch.Generator(device=device).manual_seed(22)

    result = IPpipe(
        prompt=target_name,
        image=resized_image,
        mask_image=resized_mask,
        guidance_scale=8,
        generator=seeded_generator,
    )
    return result.images
def translate_sentence(article, source, target):
    """Translate *article* from language code *source* to *target*.

    Uses the module-level NLLB ``trans_model``/``trans_tokenizer`` through a
    transformers ``translation`` pipeline.

    The text is returned unchanged when it is empty or when source and
    target are the same language. (The original short-circuited whenever
    the *target* was English, which silently skipped every translation
    *into* English, e.g. of user prompts.)

    Args:
        article: text to translate.
        source: language code of ``article`` (e.g. ``'fra_Latn'``).
        target: language code to translate into (e.g. ``'eng_Latn'``).

    Returns:
        The translated text, or ``article`` itself when no translation is
        needed.
    """
    if not article or source == target:
        return article

    translator = pipeline(
        'translation',
        model=trans_model,
        tokenizer=trans_tokenizer,
        src_lang=source,
        tgt_lang=target,
    )
    output = translator(article, max_length=400)
    return output[0]['translation_text']
# Language display name -> language code, one "<name> <code>" pair per line.
# Display names may themselves contain spaces; the code is always the last
# whitespace-delimited token on the line.
codes_as_string = '''Modern Standard Arabic arb_Arab
Danish dan_Latn
German deu_Latn
Greek ell_Grek
English eng_Latn
Estonian est_Latn
Finnish fin_Latn
French fra_Latn
Hebrew heb_Hebr
Hindi hin_Deva
Croatian hrv_Latn
Hungarian hun_Latn
Indonesian ind_Latn
Icelandic isl_Latn
Italian ita_Latn
Japanese jpn_Jpan
Korean kor_Hang
Luxembourgish ltz_Latn
Macedonian mkd_Cyrl
Maltese mlt_Latn
Dutch nld_Latn
Norwegian Bokmål nob_Latn
Polish pol_Latn
Portuguese por_Latn
Russian rus_Cyrl
Slovak slk_Latn
Slovenian slv_Latn
Spanish spa_Latn
Serbian srp_Cyrl
Swedish swe_Latn
Thai tha_Thai
Turkish tur_Latn
Ukrainian ukr_Cyrl
Vietnamese vie_Latn
Chinese (Simplified) zho_Hans'''

codes_as_string = codes_as_string.split('\n')

flores_codes = {}
for code in codes_as_string:
    if not code.strip():
        # Tolerate stray blank lines in the table.
        continue
    # Split on the LAST run of whitespace so the table parses whether the
    # separator is a tab or spaces (a hard-coded '\t' split fails to unpack
    # as soon as tabs have been converted to spaces).
    lang, lang_code = code.rsplit(None, 1)
    flores_codes[lang.strip()] = lang_code
| import gradio as gr | |
| import gc | |
gc.collect()

# --- Text shown on the main widgets -----------------------------------------
image_label = 'Please upload the image (optional)'
extract_label = 'Specify what need to be extracted from the above image'
prompt_label = 'Specify the description of image to be generated'
button_label = "Proceed"
output_label = "Generations"

# --- "Advanced Options" radio groups: label plus the selectable values ------
shot_label = 'Choose the shot type'
shot_services = [
    'close-up',
    'extreme-closeup',
    'POV',
    'medium',
    'long',
]

style_label = 'Choose the style type'
style_services = [
    'polaroid',
    'monochrome',
    'long exposure',
    'color splash',
    'Tilt shift',
]

lighting_label = 'Choose the lighting type'
lighting_services = [
    'soft',
    'ambivalent',
    'ring',
    'sun',
    'cinematic',
]

context_label = 'Choose the context'
context_services = [
    'indoor',
    'outdoor',
    'at night',
    'in the park',
    'in the beach',
    'studio',
]

lens_label = 'Choose the lens type'
lens_services = [
    'wide angle',
    'telephoto',
    '24 mm',
    'EF 70mm',
    'Bokeh',
]

device_label = 'Choose the device type'
device_services = [
    'iphone',
    'CCTV',
    'Nikon ZFX',
    'Canon',
    'Gopro',
]
def change_lang(choice):
    """Remember the chosen language and reveal the widgets with translated text.

    Stores ``choice`` in the module-level ``lang_choice`` (read later by
    ``proceed_with_generation``) and returns one ``gr.update`` per output
    component wired to ``lang_option.change``, in this order: image input,
    extract textbox, prompt textbox, button, output gallery.

    Args:
        choice: language display name; must be a key of ``flores_codes``.

    Returns:
        A list of five ``gr.update`` objects that make each component
        visible and set its translated label (or, for the button, its value).
    """
    global lang_choice
    lang_choice = choice

    english_code = flores_codes["English"]
    chosen_code = flores_codes[choice]

    def localized(text):
        # All UI strings are authored in English.
        return translate_sentence(text, english_code, chosen_code)

    # The original also ran translate_sentence(image_label, "english", choice)
    # into an unused variable; neither argument is a key-mapped language code
    # and the result was never read, so that call is removed.
    return [
        gr.update(visible=True, label=localized(image_label)),
        gr.update(visible=True, label=localized(extract_label)),
        gr.update(visible=True, label=localized(prompt_label)),
        gr.update(visible=True, value=localized(button_label)),
        # Gallery label: the original reused button_label here, so the
        # gallery ended up captioned with the button text.
        gr.update(visible=True, label=localized(output_label)),
    ]
def add_to_prompt(prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio):
    """Append the selected advanced options to *prompt_text*.

    Each non-empty option is appended as ``",<option>"`` in the fixed order
    shot, style, lighting, context, lens, device.

    Options that are ``None`` or ``''`` are skipped. The original only
    skipped ``''``; a radio group with nothing selected yields ``None``,
    which then failed with ``TypeError`` on ``"," + None``.

    Args:
        prompt_text: base prompt.
        shot_radio, style_radio, lighting_radio, context_radio, lens_radio,
            device_radio: selected option strings, or ``None``/``''`` when
            nothing is selected.

    Returns:
        ``prompt_text`` followed by the comma-prefixed selected options, or
        ``prompt_text`` unchanged when nothing is selected.
    """
    selected = [
        option
        for option in (shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio)
        if option
    ]
    if not selected:
        return prompt_text
    return ",".join([prompt_text, *selected])
def proceed_with_generation(input_file, extract_text, prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio):
    """Run free or guided image generation for the submitted form.

    * Free generation (``SDpipe``): used when no image was uploaded or no
      extraction text was given.
    * Guided generation (``generate_image``): used when an image and an
      extraction text are both present.

    In both modes the prompt is translated from the selected UI language to
    English and extended with the chosen advanced options.

    Fixes over the original:
      * a missing image is detected with ``is None`` (the original compared
        against ``""``, which neither matches ``None`` nor is a safe
        comparison for an image array);
      * an empty prompt is rejected in both modes, as the error message
        already states a prompt is the minimum requirement;
      * falls back to English when no language has been stored yet.

    Args:
        input_file: uploaded image (array accepted by ``Image.fromarray``)
            or ``None``.
        extract_text: what to extract from the image; may be empty.
        prompt_text: description of the image to generate.
        shot_radio, style_radio, lighting_radio, context_radio, lens_radio,
            device_radio: advanced options forwarded to ``add_to_prompt``.

    Returns:
        The generated images.

    Raises:
        gr.Error: when ``prompt_text`` is empty.
    """
    if not prompt_text:
        raise gr.Error("Please fill all details for guided image or atleast promt for free image rendition !")

    english_code = flores_codes["English"]
    # lang_choice is assigned by change_lang(); it does not exist until the
    # language dropdown has fired at least once.
    user_code = flores_codes[globals().get("lang_choice") or "English"]

    translated_prompt = translate_sentence(prompt_text, user_code, english_code)
    translated_prompt = add_to_prompt(translated_prompt, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio)
    print(translated_prompt)

    # An empty string is still treated as "no image" for backward
    # compatibility with the original check.
    no_image = input_file is None or (isinstance(input_file, str) and not input_file)

    if no_image or not extract_text:
        output = SDpipe(translated_prompt, height=512, width=512, num_images_per_prompt=4)
        return output.images

    translated_extract = translate_sentence(extract_text, user_code, english_code)
    print(translated_extract)
    return generate_image(Image.fromarray(input_file), translated_extract, translated_prompt)
# Gradio UI: language selector, inputs, advanced options, button, gallery.
# NOTE(review): the nesting below (which widgets sit inside each gr.Row /
# gr.Accordion) is reconstructed from the statement order -- confirm the
# intended layout.
with gr.Blocks() as demo:
    lang_option = gr.Dropdown(
        list(flores_codes.keys()),
        default='English',
        label='Please Select your Language',
    )

    with gr.Row():
        # The image input starts hidden; change_lang() makes it visible.
        input_file = gr.Image(interactive=True, label=image_label, visible=False, shape=(512, 512))
        extract_text = gr.Textbox(label=extract_label, lines=1, interactive=True, visible=True)
        prompt_text = gr.Textbox(label=prompt_label, lines=1, interactive=True, visible=True)

    with gr.Accordion("Advanced Options", open=False):
        shot_radio = gr.Radio(shot_services, label=shot_label)
        style_radio = gr.Radio(style_services, label=style_label)
        lighting_radio = gr.Radio(lighting_services, label=lighting_label)
        context_radio = gr.Radio(context_services, label=context_label)
        lens_radio = gr.Radio(lens_services, label=lens_label)
        device_radio = gr.Radio(device_services, label=device_label)

    # Button and gallery also start hidden until a language is picked.
    button = gr.Button(value=button_label, visible=False)

    with gr.Row():
        output_gallery = gr.Gallery(label=output_label, visible=False)

    # Order must match the list of updates returned by change_lang().
    lang_option.change(
        fn=change_lang,
        inputs=lang_option,
        outputs=[input_file, extract_text, prompt_text, button, output_gallery],
    )

    # Order must match the parameters of proceed_with_generation().
    button.click(
        proceed_with_generation,
        [
            input_file,
            extract_text,
            prompt_text,
            shot_radio,
            style_radio,
            lighting_radio,
            context_radio,
            lens_radio,
            device_radio,
        ],
        [output_gallery],
    )

demo.launch(debug=True)