Spaces:
Build error
Build error
| import streamlit as st | |
| import time | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| from transformers import AutoTokenizer, AutoModel, AutoConfig | |
| import torch | |
| from tqdm import tqdm | |
| import gan_cls_768 | |
| from torch.autograd import Variable | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
# Torch device used for every model and tensor in this script:
# the string "cuda" when a CUDA GPU is visible, otherwise "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def clean(txt):
    """Normalise a caption before tokenisation.

    Lower-cases *txt*, strips surrounding whitespace, and then strips any
    leading/trailing '.' characters. Whitespace that becomes exposed only
    after the dots are removed is left in place.

    Args:
        txt: Raw caption text.

    Returns:
        The normalised caption string.
    """
    return txt.lower().strip().strip('.')
max_len = 76  # every caption is padded / truncated to exactly this many tokens


def tokenize(tokenizer, txt):
    """Tokenise *txt* to a fixed length of ``max_len`` tokens.

    Shorter inputs are padded up to ``max_len``; longer ones are truncated.
    Offset mappings are not requested.

    Args:
        tokenizer: Callable tokenizer (presumably a Hugging Face tokenizer).
        txt: Text to tokenise.

    Returns:
        Whatever *tokenizer* returns for these settings (for a Hugging Face
        tokenizer, presumably a ``BatchEncoding`` of Python lists).
    """
    settings = {
        "max_length": max_len,
        "padding": "max_length",
        "truncation": True,
        "return_offsets_mapping": False,
    }
    return tokenizer(txt, **settings)
def encode(model, tokenizer, txt):
    """Embed a single caption and return the first token's final hidden state.

    The caption is normalised with ``clean`` and tokenised with ``tokenize``.
    Each field is then turned into a ``long`` tensor on ``device`` with a
    leading batch dimension of 1. *model* is switched to eval mode and run
    under ``torch.no_grad()``.

    Args:
        model: Transformer encoder whose output exposes ``last_hidden_state``.
        tokenizer: Tokenizer passed through to ``tokenize``.
        txt: Raw caption text.

    Returns:
        numpy.ndarray: hidden state at sequence position 0 (presumably the
        ``<s>`` token for RoBERTa), after squeezing out size-1 dimensions.
    """
    batch = {
        field: torch.tensor(values, dtype=torch.long, device=device)[None]
        for field, values in tokenize(tokenizer, clean(txt)).items()
    }
    model.eval()
    with torch.no_grad():
        hidden = model(**batch).last_hidden_state
    # squeeze() drops the batch dimension of 1; [0] selects sequence position 0.
    first_token = hidden.squeeze()[0]
    return first_token.cpu().numpy()
def get_model_roberta():
    """Load the pretrained ``roberta-base`` encoder and its tokenizer.

    The model is created with ``output_hidden_states=True`` and moved to
    ``device``. Nothing is cached: the weights go through ``from_pretrained``
    on every call.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    checkpoint = 'roberta-base'
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    config = AutoConfig.from_pretrained(checkpoint, output_hidden_states=True)
    encoder = AutoModel.from_pretrained(checkpoint, config=config).to(device)
    return encoder, tokenizer
def get_model_gan():
    """Build the GAN generator and load its trained weights.

    The network from ``gan_cls_768.generator()`` is moved to ``device`` and
    wrapped in ``torch.nn.DataParallel`` *before* the state dict is loaded.
    Presumably the checkpoint keys carry the ``module.`` prefix, so keep this
    order (TODO confirm). The checkpoint ``./gen_125.pth`` is
    deserialised onto the CPU, then copied into the parameters.

    Returns:
        torch.nn.DataParallel: the generator, in eval mode.
    """
    net = gan_cls_768.generator().to(device)
    generator = torch.nn.DataParallel(net)
    weights = torch.load("./gen_125.pth", map_location=torch.device('cpu'))
    generator.load_state_dict(weights)
    generator.eval()
    return generator
def generate_image(text, n):
    """Generate *n* flower images from a text description.

    The caption is embedded once with RoBERTa (``encode``). The GAN generator
    is then run *n* times, each time with that embedding and fresh Gaussian
    noise of shape (1, 100, 1, 1).

    NOTE(review): both models are reloaded from disk on every call. Consider
    Streamlit resource caching if start-up latency matters.

    Args:
        text: Free-text description of the flower.
        n: Number of images to generate (``range(n)``; 0 yields ``[]``).

    Returns:
        list[PIL.Image.Image]: the generated images, in generation order.
    """
    model, tokenizer = get_model_roberta()
    generator = get_model_gan()

    embed = encode(model, tokenizer, text)
    # Add a batch dimension of 1: the generator is conditioned on one caption.
    right_embed = torch.FloatTensor(embed).unsqueeze(0).to(device)

    images = []
    # Inference only. Without no_grad, each forward pass would record an
    # autograd graph and keep its activations alive.
    with torch.no_grad():
        for _ in tqdm(range(n)):
            # Noise is sampled on the CPU (as before) so the RNG stream is unchanged.
            noise = torch.randn(1, 100).to(device).view(1, 100, 1, 1)
            fake_images = generator(right_embed, noise)
            for image in fake_images:
                # Map generator output to 0..255. This assumes the output lies
                # in [-1, 1] (TODO confirm). Clamp so out-of-range values
                # saturate instead of wrapping when cast to uint8.
                pixels = image.mul(127.5).add(127.5).clamp(0, 255).byte()
                # CHW -> HWC, as PIL expects.
                images.append(Image.fromarray(pixels.permute(1, 2, 0).cpu().numpy()))
    return images
# Page-level settings: browser-tab title and icon, centred layout,
# sidebar initially expanded.
st.set_page_config(
    layout="centered",
    initial_sidebar_state="expanded",
    page_title="ImageGen",
    page_icon="🧊",
)

# CSS injected as raw HTML to hide Streamlit's default chrome
# (the #MainMenu element, the footer and the header bar).
hide_st_style = (
    "\n"
    "<style>\n"
    "#MainMenu {visibility: hidden;}\n"
    "footer {visibility: hidden;}\n"
    "header {visibility: hidden;}\n"
    "</style>\n"
)
st.markdown(hide_st_style, unsafe_allow_html=True)
# Example flower descriptions offered in the "Select from example" picker.
# The selected entry pre-fills the editable description box.
examples = [
    "this petal has gorgeous purple petals and a long green pedicel",
    "this petal has gorgeous green petals and a long green pedicel",
    "a couple thin, sharp, knife-like petals that have a sharp, purple, needle-like center.",
    "this flower has petals that are pink and bell shaped",
    "salmon colored round petals with veins of dark pink throughout all combined in the center with a pale yellow pistil and pollen tube.",
    "this flower features a prominent ovary covered with dozens of small stamens featuring thin white petals.",
    "delicate pink petals clumped on one green pedicel with small sepals.",
    "the flower has big yellow upright petals attached to a thick vine",
    "these bright flowers have many yellow strip petals and stamen.",
    "a large red flower with black dots and a very long stigmas.",
    "this vivid pink flower is composed of several blossoms with ruffled petals above and below a bulbous yellow-streaked center.",
    "this flower has petals that are yellow and has black lines",
    "the pink flower has bell shaped petal that is soft, smooth and enclosing stamen sticking out from the centre",
    "this flower has orange petals with many dark spots, white stamen, and dark anthers.",
    "this flower has petals that are white and has a yellow style",
    "this flower has petals that are orange and are very thin",
    "a flower with singular conical purple petal and large white pistil.",
    "the flower has bright yellow soft petals with yellow stamens.",
    "this flower has petals that are purple and have dark lines",
    "this purple flower has pointy short petals and green sepal.",
    "this flower has petals that are purple and has a yellow style",
    "the petals on this flower are orange with a purple pistil.",
    "a flower with no visible petals and purple pistils in the center.",
    "a star shaped flower with five white petals with purple lines running through them.",
    "the petals on this flower are bright yellow in color and there are two rows. the bottom layer lays flat, while the top layer is shaped like a bowl around the pistil.",
    "this flower features a purple stigma surrounded by pointed waxy orange petals.",
]
def app():
    """Render the text-to-flower demo page.

    The page shows:
      - the title and the paper reference;
      - an example picker;
      - an editable description box pre-filled with the chosen example;
      - a model selector, which currently has a single option.

    When "Generate" is pressed, ``n_rows * n_cols`` images are generated for
    the description and drawn as a grid in the right-hand column.
    """
    n_rows, n_cols = 3, 4  # grid shape; also fixes how many images are generated

    st.title("Text to Flower")
    st.markdown(
        """
**Demo for Paper:** Synthesizing Realistic Images from Textual Descriptions: A Transformer-Based GAN Approach.
Presented in *"International Conference on Next-Generation Computing, IoT and Machine Learning (NCIM 2023)"*
"""
    )
    example = st.selectbox("Select from example", examples)
    row1_col1, row1_col2 = st.columns([2, 3])
    with row1_col1:
        caption = st.text_area("Write your flower description here:", example, height=120)
        # Only one backend exists, so the selection is shown but not read.
        st.selectbox(
            "Select a Model", ["Convolutional GAN with RoBERTa", ], index=0
        )
        if st.button("Generate", type="primary"):
            with st.spinner("Generating Flower Images..."):
                imgs = generate_image(caption, n_rows * n_cols)
            with row1_col2:
                st.markdown("Generated Flower Images:")
                fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols)
                for axis, img in zip(axes.flatten(), imgs):
                    axis.imshow(img)
                    axis.axis('off')
                fig.tight_layout()
                st.pyplot(fig)
                # Without this, pyplot's global figure registry grows by one
                # figure per click for the lifetime of the server process.
                plt.close(fig)
# Streamlit executes the script as ``__main__``, so the page still renders
# under ``streamlit run``. Importing this module no longer draws the UI.
if __name__ == "__main__":
    app()
| # # Display a footer with links and credits | |
| #st.markdown("---") | |
| #st.markdown("Back to [www.shamimahamed.com](https://www.shamimahamed.com/).") | |
| # #st.markdown("Data provided by [The Feedback Prize - ELLIPSE Corpus Scoring Challenge on Kaggle](https://www.kaggle.com/c/feedbackprize-ellipse-corpus-scoring-challenge)") | |