Spaces:
Running
Running
| import os | |
| import json | |
| import base64 | |
| import requests | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| # MAX_LEN = 40 | |
| # STEP = 2 | |
| # x = np.arange(0, MAX_LEN, STEP) | |
| # token_counts = [0] * (MAX_LEN//STEP) | |
| # with open("prompts.json", 'r') as f: | |
| # prompts = json.load(f) | |
| # for prompt in prompts: | |
| # tokens = len(prompt.strip().split(' ')) | |
| # token_counts[min(tokens//STEP, MAX_LEN//STEP-1)] += 1 | |
| # plt.xticks(x, x+1) | |
| # plt.xlabel("token counts") | |
| # plt.bar(x, token_counts, width=1.3) | |
| # # plt.show() | |
| # plt.savefig("token_counts.png") | |
## Generate image prompts
# Send every prompt in prompts.json to the Stability AI v1 text-to-image
# endpoint and write the returned images to
# ./images/<prompt index>/v1_txt2img_<sample index>.png.
# Failed prompts (HTTP or network errors) are logged and skipped so one bad
# request does not abort the whole batch.

with open("prompts.json", encoding="utf-8") as f:
    text_prompts = json.load(f)  # presumably a JSON list of prompt strings

engine_id = "stable-diffusion-v1-6"
api_host = os.getenv('API_HOST', 'https://api.stability.ai')

# SECURITY: the key must come from the environment only. An earlier revision
# embedded a live "sk-..." key as the getenv fallback, which also made the
# missing-key check below unreachable. That key must be revoked and rotated.
api_key = os.getenv("STABILITY_API_KEY")
if not api_key:  # also rejects an empty string, which would send "Bearer "
    raise RuntimeError("Missing Stability API key.")

# Prompts with an index below START_INDEX are skipped, so an interrupted run
# can be resumed. The default of 21 matches the previously hard-coded
# `idx <= 20` skip.
START_INDEX = int(os.getenv("START_INDEX", "21"))

# Seconds to wait on the API for a single prompt. Without a timeout,
# requests.post can block indefinitely on a stalled connection.
REQUEST_TIMEOUT = 300

for idx, text in enumerate(text_prompts):
    if idx < START_INDEX:
        continue
    print(f"Start generate prompt[{idx}]: {text}")
    try:
        response = requests.post(
            f"{api_host}/v1/generation/{engine_id}/text-to-image",
            headers={
                "Content-Type": "application/json",
                "Accept": "application/json",
                "Authorization": f"Bearer {api_key}",
            },
            json={
                "text_prompts": [
                    {
                        "text": text.strip(),
                    }
                ],
                "cfg_scale": 7,
                "height": 1024,
                "width": 1024,
                "samples": 3,
                "steps": 30,
            },
            timeout=REQUEST_TIMEOUT,
        )
    except requests.RequestException as err:
        # Treat connection errors and timeouts like a non-200 response:
        # log them and continue with the next prompt.
        print(f"{idx} Failed!!! {err}")
        continue
    if response.status_code != 200:
        print(f"{idx} Failed!!! {response.text}")
        continue
    print("Finished!")
    data = response.json()
    # Each artifact carries one generated image as base64-encoded bytes.
    for i, image in enumerate(data["artifacts"]):
        img_path = f"./images/{idx}/v1_txt2img_{i}.png"
        os.makedirs(os.path.dirname(img_path), exist_ok=True)
        with open(img_path, "wb") as f:
            f.write(base64.b64decode(image["base64"]))