| | import streamlit as st |
| | import streamlit.components.v1 as components |
| | from PIL import Image |
| | import requests |
| |
|
| | from predict import generate_text |
| | from model import load_model |
| |
|
| | from streamlit_image_select import image_select |
| |
|
| |
|
| | |
# Configure the browser tab; per Streamlit docs, set_page_config must be
# the first Streamlit command executed in the script.
st.set_page_config(page_title="Caption Machine", page_icon="📸")

# Load the model exactly once per browser session.  Streamlit re-runs this
# entire script on every user interaction, so the (expensive) load_model()
# call must live *inside* the session_state guard — otherwise the weights
# are reloaded on every rerun and the cache is pointless.
if 'model' not in st.session_state:
    model, image_transform, tokenizer = load_model()
    st.session_state['model'] = model
    st.session_state['image_transform'] = image_transform
    st.session_state['tokenizer'] = tokenizer
| |
|
| |
|
| |
|
| | |
# Page-level CSS, injected once:
#  - force Streamlit columns to a 50% width so the layout stays two-up
#    even on narrow viewports;
#  - style the "Or" divider as a centered label flanked by horizontal rules.
_PAGE_CSS = """<style>
[data-testid="column"] {
width: calc(50% - 1rem);
flex: 1 1 calc(50% - 1rem);
min-width: calc(50% - 1rem);
}

.separator {
display: flex;
align-items: center;
text-align: center;
}

.separator::before,
.separator::after {
content: '';
flex: 1;
border-bottom: 1px solid #000;
}

.separator:not(:empty)::before {
margin-right: .25em;
}

.separator:not(:empty)::after {
margin-left: .25em;
}

</style>"""

st.write(_PAGE_CSS, unsafe_allow_html=True)
| |
|
| | |
# Page heading plus a one-paragraph description of the captioning approach.
st.title("Image Captioner")

_ABOUT = (
    "This app utilizes OpenAI's [GPT-2](https://openai.com/research/better-language-models) and [CLIP](https://openai.com/research/clip) models to generate image captions. The model architecture was inspired by [ClipCap: CLIP Prefix for Image Captioning](https://arxiv.org/abs/2111.09734), which uses CLIP encoding as prefix and fine-tune GPT-2 model to generate the caption."
)
st.markdown(_ABOUT)
| |
|
| |
|
| |
|
| | |
# Gallery of sample Flickr photos the user can pick instead of uploading.
_SAMPLE_IMAGES = [
    "https://farm5.staticflickr.com/4084/5093294428_2f50d54acb_z.jpg",
    "https://farm8.staticflickr.com/7044/6855243647_cd204d079c_z.jpg",
    "http://farm4.staticflickr.com/3016/2650267987_f478c8d682_z.jpg",
    "https://farm8.staticflickr.com/7249/6913786280_c145ecc433_z.jpg",
]

select_file = image_select(
    label="Select a photo:",
    images=_SAMPLE_IMAGES,
)

# Visual "Or" divider between the gallery and the uploader; its look comes
# from the .separator CSS rules injected earlier in this script.
st.markdown("<div class='separator'>Or</div>", unsafe_allow_html=True)

# Alternatively, let the user supply their own image file.
upload_file = st.file_uploader("Upload an image:", type=['png', 'jpg', 'jpeg'])
| |
|
| |
|
| | |
# Once an image is available — preferring an uploaded file over a gallery
# selection — display it and run caption generation.
if upload_file or select_file:

    img = None

    if upload_file:
        # The uploaded file is an in-memory buffer PIL can open directly.
        img = Image.open(upload_file)
    elif select_file:
        # Sample images are fetched over HTTP.  A timeout keeps the app from
        # hanging indefinitely on a slow host, and raise_for_status surfaces
        # HTTP errors instead of handing PIL a non-image error page.
        response = requests.get(select_file, stream=True, timeout=10)
        response.raise_for_status()
        img = Image.open(response.raw)

    st.image(img)

    # Run inference with the session-cached model components.
    with st.spinner('Generating caption...'):
        caption = generate_text(
            st.session_state['model'],
            img,
            st.session_state['tokenizer'],
            st.session_state['image_transform'],
        )

    st.success(f"Result: {caption}")
| | |
| |
|
| | |
# Collapsible section explaining the ClipCap-style architecture.
# Step 2's original text was missing its verb ("image embedding into ...");
# fixed to read as a complete instruction.
with st.expander("See model architecture"):
    st.markdown(
        """
Steps:
1. Feed image into CLIP Image Encoder to get image embedding
2. Map image embedding into text embedding shape
3. Feed Text into GPT-2 Text Embedder to get a text embedding
4. Concatenate two embeddings and feed into GPT-2 Attention Layers
""")

    st.write(" \nModel Architecture: ")
    # Diagram shipped alongside the app; path is relative to the working dir.
    model_img = Image.open('./model.png')
    st.image(model_img, width=450)
| |
|
| |
|
| |
|