Spaces:
Sleeping
Sleeping
| import io | |
| import os | |
| import random | |
| import sys | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import streamlit as st | |
| import torch | |
| from diffusers import StableDiffusionPipeline | |
| # from dotenv import load_dotenv | |
| from huggingface_hub import notebook_login | |
| from PIL import Image | |
# Hugging Face access token used to download the Stable Diffusion weights.
# An ``ACCESS_TOKEN`` environment variable takes precedence (e.g. exported in a
# local shell, or loaded from a .env file with python-dotenv); otherwise the
# value comes from Streamlit secrets (the Hugging Face Space secret).
ACCESS_TOKEN = os.environ.get("ACCESS_TOKEN") or st.secrets["ACCESS_TOKEN"]

# Make the sibling ``simulation`` module importable no matter which directory
# the app is launched from ("./" only worked from this file's own directory).
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# Expected to provide create_generation, ga_solve and get_word_for_image_generate
# (used by main below).
from simulation import *  # noqa: E402,F403

# Fix the seeds of Python's and NumPy's RNGs (presumably used by the GA in
# ``simulation``) so runs are reproducible. torch is not seeded here, so the
# generated image itself can still vary between runs.
SEED = 42
np.random.seed(seed=SEED)
random.seed(SEED)
def _ordinal(n):
    """Return ``n`` followed by its English ordinal suffix.

    Examples: 1 -> "1st", 2 -> "2nd", 3 -> "3rd", 11 -> "11th", 21 -> "21st",
    112 -> "112th", 1000 -> "1000th".
    """
    if 10 <= n % 100 <= 20:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
    return f"{n}{suffix}"


def _null_safety(images, **kwargs):
    """Stand-in safety checker that flags no image as NSFW (disables the filter).

    Returns the images unchanged plus one ``False`` flag per image. The pipeline
    expects a per-image sequence of flags (newer diffusers releases iterate over
    it), so returning a bare ``False`` can raise ``TypeError``.
    """
    return images, [False] * len(images)


def _render_titled_image(image, title):
    """Plot ``image`` without ticks under ``title`` and return the figure as a NumPy array.

    The figure is saved to an in-memory PNG and decoded with PIL, and the
    result is converted with ``np.array``. The figure is always closed: pyplot
    keeps every open figure alive, so leaving it open would leak one figure
    per button press in a long-running Streamlit server.
    """
    # Scope the smaller font to this figure instead of mutating global rcParams.
    with plt.rc_context({"font.size": 9}):
        fig, ax = plt.subplots(figsize=(8, 8))
        try:
            ax.set_title(title)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.imshow(image)
            fig.tight_layout()
            with io.BytesIO() as buf:
                fig.savefig(buf, format="png")
                buf.seek(0)
                with Image.open(buf) as im:
                    # np.array forces PIL to decode before the buffer closes.
                    return np.array(im)
        finally:
            plt.close(fig)


def main():
    """Streamlit app: evolve creature traits with a GA, then draw the best one.

    Shows two sliders (population size and number of generations) and a
    "実行" button. When the button is pressed it:

    1. creates a first generation with ``create_generation`` and evolves it
       with ``ga_solve`` (both from the ``simulation`` module);
    2. turns the ``best`` result for the last generation into a text prompt
       with ``get_word_for_image_generate`` plus a habitat phrase;
    3. generates an image from that prompt with Stable Diffusion v1.4
       (safety checker disabled) and shows it with ``st.image``, titled with
       the generation number and the prompt.
    """
    # Phrase appended to every prompt to describe the creature's habitat
    # (the leading space separates it from the generated words).
    HABITAT_WORDS = " Alien from Mars"
    # GA parameters.
    GENOME_SIZE = 4  # size of each genome array; every element is either 0 or 1
    TOURNAMENT_NUM = 10  # number of individuals competing in tournament selection
    CROSSOVER_PB = 0.8  # probability of crossover
    MUTATION_PB = 0.5  # probability of mutation

    # NOTE(review): the slider allows populations smaller than TOURNAMENT_NUM;
    # confirm ga_solve's tournament selection handles that case.
    population_size = st.slider(
        label="人口数",
        min_value=3,
        max_value=3000,
        value=500,
    )
    num_generation = st.slider(
        label="世代数",
        min_value=10,
        max_value=10000,
        value=1000,
    )

    # Candidate keywords for each trait; the GA result selects among them.
    word_dict = {
        "body_size": ["Fingertip sized", "Palm sized", "", "Tall", "Giant"],
        "body_hair": ["Bald", "Smooth", "", "Furry", "Very Furry"],
        "herd_num": ["Lone", "Pair", "", "Herd of", "Swarm of"],
        "eating": ["No teeth", "Herbivorous", "Omnivorous", "Carnivorous", "Fang"],
        "body_color": [
            "Lightest skin",
            "Lighter skin",
            "",
            "Darker skin",
            "Darkest skin",
        ],
        "ferocity": ["Peaceful", "Gentle", "", "Ferocious", "Tyrannical"],
    }

    if not st.button("実行", key="ga"):
        return

    # --- Genetic algorithm ---
    st.write("遺伝アルゴリズムの実行")
    progress_bar_ga = st.progress(0)
    generation = create_generation(population_size, GENOME_SIZE)
    progress_bar_ga.progress(50)
    best, _worst = ga_solve(
        generation,
        num_generation,
        population_size,
        TOURNAMENT_NUM,
        CROSSOVER_PB,
        MUTATION_PB,
    )
    progress_bar_ga.progress(100)
    st.write("遺伝アルゴリズム処理の終了")

    # --- Image generation ---
    st.write("画像生成の実行")
    progress_bar_image = st.progress(0)
    # NOTE(review): the model is reloaded on every button press; consider
    # caching it (e.g. st.cache_resource on Streamlit >= 1.18).
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", use_auth_token=ACCESS_TOKEN
    )
    pipe.enable_attention_slicing()
    progress_bar_image.progress(7)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("used device is", device)
    pipe.to(device)
    # Disable the NSFW safety checker.
    pipe.safety_checker = _null_safety

    last_generation = num_generation - 1
    words = (
        get_word_for_image_generate(word_dict, best, last_generation)
        + HABITAT_WORDS
    )
    image = pipe(words)["images"][0]
    numpy_image = _render_titled_image(
        image, f"{_ordinal(last_generation + 1)}\n{words}."
    )
    progress_bar_image.progress(100)
    st.image(numpy_image)
if __name__ == "__main__":
    # Run the app when this file is executed as a script (e.g. via `streamlit run`).
    main()