| import streamlit as st |
|
|
| from utils import generation,load_model |
|
|
| |
# --- Page header -----------------------------------------------------------
st.title("Gen of butterfly")
st.markdown("This is lightweight_gan")

# --- Sidebar: tagline, logo and author link --------------------------------
with st.sidebar:
    st.subheader("Even a sloth can learn Machine Learning—one slow line of code at a time!")
    st.image("assets/logo.png", width=200)
    st.caption("https://wgcv.me")

# --- Model setup -----------------------------------------------------------
# NOTE(review): model_id looks like a model-hub repository id; confirm what
# utils.load_model expects.
model_id = "ceyda/butterfly_cropped_uniq1K_512"
model = load_model(model_id)

# Number of images requested per generation call; also used as the number of
# display columns further down.
n_gen = 16
|
|
def run():
    """Generate a fresh batch of images and stash it in the session state.

    Calls ``generation`` with the module-level ``model`` and a batch size of
    ``n_gen``, showing a spinner while it works, and stores the result under
    the ``"ims"`` key of ``st.session_state``.
    """
    # NOTE(review): the spinner text mentions loading the model, but the call
    # it wraps is image generation - confirm the intended wording.
    with st.spinner("Loading the model"):
        st.session_state["ims"] = generation(model, batch_size=n_gen)
|
|
# When the session has no "ims" entry yet, create the key (as None, so it
# exists even if generation fails) and then produce an initial batch.
if "ims" not in st.session_state:
    st.session_state["ims"] = None
    run()
|
|
|
|
# Images produced by the most recent call to run(), if any.
ims = st.session_state["ims"]

# Clicking the button invokes run() as a callback, which refreshes the images
# stored in the session state.
run_button = st.button(
    "Gen AI butterfly",
    on_click=run,
    help="This would run the model",
)

# Lay the images out across n_gen columns; the modulo wraps back to the first
# column if there are more images than columns.
if ims is not None:
    cols = st.columns(n_gen)
    for position, image in enumerate(ims):
        cols[position % n_gen].image(image, use_column_width=True)
|
|