Spaces:
Runtime error
Runtime error
| # file stuff | |
| import os | |
| from io import BytesIO | |
| #image generation stuff | |
| from PIL import Image | |
| # gradio / hf / image gen stuff | |
| import gradio as gr | |
| from dotenv import load_dotenv | |
| from google.cloud import aiplatform | |
| import vertexai | |
| from vertexai.preview.vision_models import ImageGenerationModel | |
| from vertexai import preview | |
| # GCP credentials stuff | |
| import json | |
| import pybase64 | |
| from google.oauth2 import service_account | |
| import google.auth | |
# Load environment variables (e.g. from a local .env file) before reading them.
load_dotenv()

# GCP service-account credentials are supplied base64-encoded in the IMAGEN
# env var. Fail fast with an actionable message instead of an opaque error
# from b64decode(None) when the variable is missing or empty.
_imagen_b64 = os.getenv("IMAGEN")
if not _imagen_b64:
    raise RuntimeError(
        "The IMAGEN environment variable is not set. It must contain the "
        "base64-encoded GCP service-account JSON key."
    )
service_account_json = pybase64.b64decode(_imagen_b64)
service_account_info = json.loads(service_account_json)
credentials = service_account.Credentials.from_service_account_info(service_account_info)
project = "pdr-imagen"
# Make these credentials/project the defaults for subsequent Vertex AI calls.
aiplatform.init(project=project, credentials=credentials)
# Password enforcement is on only when DO_ENFORCE_PW is "true". The value is
# stripped and lower-cased so "True" / " true " also count; it stays None when
# the variable is unset so `== "true"` / `== "false"` checks behave as before.
_raw_do_enforce_pw = os.getenv("DO_ENFORCE_PW")
DO_ENFORCE_PW = _raw_do_enforce_pw.strip().lower() if _raw_do_enforce_pw is not None else None
def trigger_max_gens():
    """Show a Gradio warning toast saying the generation limit was reached.

    This is the click handler of the hidden ``trigger-max-gens-btn`` button,
    which the page's client-side script clicks once its localStorage counter
    hits the configured maximum.
    """
    limit_message = "🖼️ Max Image Generations Reached! 🖼️"
    gr.Warning(limit_message)
def _password_matches(pw):
    """Return True if *pw* equals the PW env var (constant-time comparison)."""
    import hmac

    expected = os.getenv("PW")
    # An unset PW (or missing input) never matches, so enforcement fails closed.
    if expected is None or pw is None:
        return False
    return hmac.compare_digest(pw.encode("utf-8"), expected.encode("utf-8"))


def generate_image(pw, prompt, model_name):
    """Generate one image from *prompt* with a Vertex AI Imagen model.

    Args:
        pw: Password entered by the user; only checked when DO_ENFORCE_PW
            is "true".
        prompt: Text prompt describing the desired image.
        model_name: Model id passed to ``ImageGenerationModel.from_pretrained``
            (e.g. "imagegeneration@005").

    Returns:
        The generated image as a PIL image.

    Raises:
        gr.Error: If the password is wrong, no image was returned, or the
            Vertex AI call / image decoding failed.
    """
    if DO_ENFORCE_PW == "true" and not _password_matches(pw):
        raise gr.Error("Invalid password. Please try again.")
    try:
        model = ImageGenerationModel.from_pretrained(model_name)
        response = model.generate_images(
            prompt=prompt,
            number_of_images=1,
        )
        # NOTE(review): _image_bytes is a private attribute of the SDK's image
        # object; confirm it still exists when upgrading the vertexai SDK.
        image_bytes = response[0]._image_bytes
        image = Image.open(BytesIO(image_bytes))
    except IndexError as e:
        # An empty response; presumably the prompt or output was filtered by
        # Vertex AI's safety settings — TODO confirm against the SDK docs.
        print(e)
        raise gr.Error("No image was returned for this prompt. Please try rephrasing it.") from e
    except Exception as e:
        # Top-level boundary of a UI callback: log, then show a friendly error.
        print(e)
        raise gr.Error("An error occurred while generating the image") from e
    return image
# Client-side script passed to gr.Blocks(js=...). It caps the number of image
# generations per browser with a localStorage counter and, once the cap is
# reached, disables the Generate button and clicks the hidden
# "trigger-max-gens-btn" so the server shows a warning toast.
# Only clicks on the Generate button are counted; loading the page merely
# checks the limit (previously every page load consumed one generation).
# NOTE(review): this is a UX guard only — it is bypassed by clearing
# localStorage, pressing Enter in the prompt box (text.submit), or calling the
# "generate_image" API endpoint. No CSS for the 'not-visible' class is defined
# in this file; confirm it exists elsewhere or that class has no effect.
custom_js = """
function customJS() {
    // Limit image generations (per browser, stored in localStorage).
    const MAX_GENERATIONS = 10;
    const DO_ENFORCE_MAX_GENERATIONS = true;
    const STORAGE_KEY = 'currentGenerations';

    // Number of generations recorded so far (0 if none or unparsable).
    const getGenerations = function() {
        const stored = parseInt(localStorage.getItem(STORAGE_KEY), 10);
        return Number.isNaN(stored) ? 0 : stored;
    };

    const disableGenerateButton = function() {
        const btn = document.getElementById('btn_generate-images');
        if (btn) {
            btn.disabled = true;
            btn.classList.add('not-visible');
        }
    };

    // Clicks the hidden Gradio button whose handler shows the toast.
    // Guarded because the hidden element may not be present in the DOM.
    const triggerMaxGenerationsToast = function() {
        const toastBtn = document.getElementById('trigger-max-gens-btn');
        if (toastBtn) {
            toastBtn.click();
        }
    };

    // Disable generating (and notify) once the limit has been reached.
    const enforceLimit = function() {
        if (getGenerations() >= MAX_GENERATIONS) {
            triggerMaxGenerationsToast();
            disableGenerateButton();
        }
    };

    // Count one generation, then re-check the limit.
    const recordGeneration = function() {
        const count = getGenerations() + 1;
        localStorage.setItem(STORAGE_KEY, count);
        console.log(`${count} / ${MAX_GENERATIONS}`);
        enforceLimit();
    };

    if (!DO_ENFORCE_MAX_GENERATIONS) {
        return;
    }
    // Page load only checks the limit; loading the page is not a generation.
    enforceLimit();
    const generateBtn = document.getElementById('btn_generate-images');
    if (generateBtn) {
        generateBtn.addEventListener('click', recordGeneration);
    }
}
"""
# Show the password box and the access note only when generate_image actually
# enforces the password (DO_ENFORCE_PW == "true"); previously they were also
# shown when DO_ENFORCE_PW was unset, even though no check was performed.
_show_pw = DO_ENFORCE_PW == "true"

with gr.Blocks(js=custom_js) as demo:
    gr.Markdown("# <center>Google Vertex Imagen Generator</center>")

    # Password
    pw = gr.Textbox(
        label="Password",
        type="password",
        placeholder="Enter the password to unlock the service",
        visible=_show_pw,
    )
    gr.Markdown(
        "Need access? Send a DM to @HeaversMike on Twitter or send me an email / Slack msg.",
        visible=_show_pw,
    )

    # Instructions. The accordion label is passed once: it was previously
    # given both positionally and as label=, which conflicts.
    with gr.Accordion("Instructions & Tips", open=False):
        with gr.Row():
            gr.Markdown("**Tips**: Use adjectives (size,color,mood), specify the visual style (realistic,cartoon,8-bit), explain the point of view (from above,first person,wide angle) ")

    # Prompt
    with gr.Accordion("Prompt", open=True):
        text = gr.Textbox(
            label="What do you want to create?",
            placeholder="Enter your text and then click on the \"Generate Images\" button",
        )
        model = gr.Dropdown(
            choices=["imagegeneration@002", "imagegeneration@005"],
            label="Model",
            value="imagegeneration@005",
        )
        with gr.Row():
            # elem_id is referenced by custom_js to count clicks and disable it.
            btn = gr.Button("Generate Images", variant="primary", elem_id="btn_generate-images")

    # Output
    with gr.Accordion("Image Output", open=True):
        output_image = gr.Image(label="Image")

    with gr.Row():
        # Hidden button clicked by custom_js to show the "max generations" toast.
        trigger_max_gens_btn = gr.Button(
            value="Show Max Gens Reached",
            visible=False,
            elem_id="trigger-max-gens-btn",
        )

    btn.click(fn=generate_image, inputs=[pw, text, model], outputs=output_image, api_name=False)
    # Pressing Enter in the prompt box also generates; this event is the one
    # exposed as the "generate_image" API endpoint.
    text.submit(fn=generate_image, inputs=[pw, text, model], outputs=output_image, api_name="generate_image")

    # JS-triggered functionality
    trigger_max_gens_btn.click(trigger_max_gens, None, None)

demo.launch(share=False)