Spaces:
Runtime error
Runtime error
| from htbuilder import HtmlElement, div, ul, li, br, hr, a, p, img, styles, classes, fonts | |
| from htbuilder.units import percent, px | |
| from htbuilder.funcs import rgba, rgb | |
| import streamlit as st | |
| import os | |
| import sys | |
| import argparse | |
| import clip | |
| import numpy as np | |
| from PIL import Image | |
| from dalle.models import Dalle | |
| from dalle.utils.utils import set_seed, clip_score | |
def link(link, text, **style):
    """Build an anchor element that opens *link* in a new browser tab.

    Args:
        link: Target URL, placed in the ``href`` attribute.
        text: Content shown inside the anchor.
        **style: CSS properties forwarded to ``htbuilder.styles``.

    Returns:
        The htbuilder ``a`` element wrapping *text*.
    """
    inline_css = styles(**style)
    anchor = a(_href=link, _target="_blank", style=inline_css)
    return anchor(text)
def layout(*args):
    """Render a fixed footer at the bottom of the Streamlit page.

    Injects a stylesheet that hides Streamlit's main menu and default footer
    and offsets the app's bottom edge. Then renders a fixed-position ``div``
    holding a horizontal rule followed by a paragraph built from *args*.

    Args:
        *args: Footer content. Plain strings and htbuilder ``HtmlElement``
            instances are appended to the paragraph in order; values of any
            other type are ignored.
    """
    # There must be no space between "#" and "MainMenu": "# MainMenu" is not a
    # valid CSS ID selector, so the browser would discard the whole rule and
    # the menu would stay visible.
    style = """
    <style>
    #MainMenu {visibility: hidden;}
    footer {visibility: hidden;}
    .stApp { bottom: 105px; }
    </style>
    """

    style_div = styles(
        position="fixed",
        left=0,
        bottom=0,
        margin=px(0, 0, 0, 0),
        width=percent(100),
        color="black",
        text_align="center",
        height="auto",
        opacity=1,
    )

    style_hr = styles(
        display="block",
        margin=px(8, 8, "auto", "auto"),
        border_style="inset",
        border_width=px(2),
    )

    body = p()
    foot = div(style=style_div)(hr(style=style_hr), body)

    st.markdown(style, unsafe_allow_html=True)

    # body is already nested inside foot. Calling it appends children, so the
    # content added here appears when foot is rendered below.
    for arg in args:
        if isinstance(arg, (str, HtmlElement)):
            body(arg)

    st.markdown(str(foot), unsafe_allow_html=True)
def footer():
    """Render the page footer with author, project and advisor credits."""
    credit_line = [
        "Created by ",
        link("https://jonathanmalott.com", "Jonathan Malott"),
    ]
    affiliation_line = [
        link("https://bridgingbarriers.utexas.edu/good-systems", "Good Systems"),
        " Grand Challenge",
        ", The University of Texas at Austin.",
        " Advised by Dr. Junfeng Jiao.",
    ]
    # Line breaks separate the two credit lines and pad the end of the footer.
    layout(*credit_line, br(), *affiliation_line, br(), br())
| #footer() | |
def generate(prompt, crazy, k):
    """Sample one image for *prompt* with minDALL-E and store it in the session.

    The image is sampled on the CPU and re-ranked with CLIP. It is then
    converted to a PIL image and appended, together with its generation
    settings, to ``st.session_state.results``.

    Args:
        prompt: Text prompt the image is conditioned on.
        crazy: Softmax temperature passed to the sampler.
        k: Top-k cutoff passed to the sampler. It is also stored with the
            result, where ``drawGrid`` shows it as "top k".
    """
    device = 'cpu'

    # NOTE(review): both models are reloaded on every call, which is slow.
    # Consider a process-wide cache if the installed Streamlit supports one.
    model = Dalle.from_pretrained('minDALL-E/1.3B')  # This will automatically download the pretrained model.
    model.to(device=device)

    num_candidates = 1
    set_seed(np.random.randint(0, 10000))

    # Sampling. Use the caller's k so the displayed "top k" is the value
    # actually applied (it used to be hard-coded to 2048).
    images = model.sampling(prompt=prompt,
                            top_k=int(k),
                            top_p=None,
                            softmax_temperature=crazy,
                            num_candidates=num_candidates,
                            device=device).cpu().numpy()
    # Move axis 1 to the end (channels-last). Assumes the sampler returns
    # (N, C, H, W); TODO confirm.
    images = np.transpose(images, (0, 2, 3, 1))

    # CLIP re-ranking
    model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
    model_clip.to(device=device)
    rank = clip_score(prompt=prompt,
                      images=images,
                      model_clip=model_clip,
                      preprocess_clip=preprocess_clip,
                      device=device)
    # Pick exactly one image. np.atleast_1d handles both a 0-d index (single
    # candidate) and an index array. Indexing with a whole array would give a
    # 4-D stack that Image.fromarray cannot convert. Assumes clip_score
    # orders indices best-first; verify against dalle.utils.
    best = int(np.atleast_1d(rank)[0])
    result = images[best]

    item = {
        'prompt': prompt,
        'crazy': crazy,
        'k': k,
        'image': Image.fromarray((result * 255).astype(np.uint8)),
    }
    st.session_state.results.append(item)
def drawGrid():
    """Render all stored results, grouped by their generation settings.

    Entries of ``st.session_state.results`` are grouped by (prompt,
    temperature, top-k), newest first. Each group gets a subheader naming its
    settings, and its images are placed round-robin across three columns.
    """
    groups = {}
    # Walk newest-first so groups (and images within them) appear in reverse
    # insertion order; dicts keep first-seen key order.
    for r in reversed(st.session_state.results):
        key = (r['prompt'], r['crazy'], r['k'])
        groups.setdefault(key, []).append(r)

    for results in groups.values():
        first = results[0]
        st.subheader(f"{first['prompt']} (temperature:{first['crazy']}, top k:{first['k']})")
        cols = st.columns(3)
        for ix, item in enumerate(results):
            with cols[ix % 3]:
                st.image(item["image"])