| ############################################################################################################################# | |
| # Filename : app.py | |
# Description: A Streamlit application that chains five models back to back, feeding each model's output into the next.
| # Models used: | |
| # 1. Visual Question Answering (VQA). | |
| # 2. Fill-Mask. | |
| # 3. Text2text Generation. | |
| # 4. Text Generation. | |
# 5. Topic Modeling (BERTopic).
| # Author : Georgios Ioannou | |
| # | |
| # Copyright © 2024 by Georgios Ioannou | |
| ############################################################################################################################# | |
| # Import libraries. | |
| import streamlit as st # Build the GUI of the application. | |
| import torch # Load Salesforce/blip model(s) on GPU. | |
| from bertopic import BERTopic # Topic model inference. | |
| from PIL import Image # Open and identify a given image file. | |
| from transformers import ( | |
| pipeline, | |
| BlipProcessor, | |
| BlipForQuestionAnswering, | |
| ) # VQA model inference. | |
| ############################################################################################################################# | |
# Function to apply local CSS.
def local_css(file_name):
    """Inject the stylesheet stored in ``file_name`` into the Streamlit page.

    The file's contents are wrapped in a ``<style>`` tag and rendered with
    ``st.markdown`` with HTML enabled.

    Args:
        file_name: Path to the CSS file to read.
    """
    # Read explicitly as UTF-8 so the result does not depend on the host locale.
    with open(file_name, encoding="utf-8") as css_file:
        css = css_file.read()
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
| ############################################################################################################################# | |
# Model 1.
# Model 1 gets input from the user.
# User -> Model 1
# Load the Visual Question Answering (VQA) model directly.
# Using transformers.
@st.cache_resource
def load_model_blip():
    """Load the Salesforce/blip VQA processor and model (base checkpoint).

    Streamlit re-executes this script on every user interaction. The result is
    therefore cached with ``st.cache_resource``, so the weights are loaded once
    and the same objects are reused across reruns instead of being rebuilt
    each time.

    Returns:
        A ``(BlipProcessor, BlipForQuestionAnswering)`` tuple.
    """
    blip_processor_base = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
    blip_model_base = BlipForQuestionAnswering.from_pretrained(
        "Salesforce/blip-vqa-base"
    )
    # Backup model: to use it, load and return these instead.
    # blip_processor_large = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
    # blip_model_large = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")
    # return blip_processor_large, blip_model_large
    return blip_processor_base, blip_model_base
# General function for any Salesforce/blip model(s).
# VQA model.
def generate_answer_blip(processor, model, image, question, max_length=50):
    """Answer ``question`` about ``image`` with a Salesforce/blip VQA model.

    Args:
        processor: A ``BlipProcessor`` (or compatible object) used to encode
            the image/question pair and to decode the generated token ids.
        model: A ``BlipForQuestionAnswering`` (or compatible) model. It may
            live on any device; the encoded inputs are moved to match it.
        image: The image to ask about.
        question: The free-text question about ``image``.
        max_length: Maximum generation length passed to ``model.generate``.
            Defaults to 50, the value previously hard-coded here.

    Returns:
        The list of decoded answers (one per batch item) with special tokens
        removed.
    """
    # Prepare image + question as PyTorch tensors.
    inputs = processor(images=image, text=question, return_tensors="pt")
    # The processor returns CPU tensors, but the model may have been moved to a
    # GPU; mixing devices makes generate() fail, so align the inputs with it.
    inputs = inputs.to(model.device)
    generated_ids = model.generate(**inputs, max_length=max_length)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)
# Generate answer from the Salesforce/blip model(s).
# VQA model.
def generate_answer(image, question):
    """Ask the base Salesforce/blip VQA model ``question`` about ``image``.

    Relies on the module-level ``blip_processor_base`` / ``blip_model_base``
    pair. To use the larger "capfilt" backup model instead, load it in
    ``load_model_blip`` and pass that pair here.

    Returns:
        The list of decoded answers produced by ``generate_answer_blip``.
    """
    return generate_answer_blip(blip_processor_base, blip_model_base, image, question)
| ############################################################################################################################# | |
# Model 2.
# Model 2 gets input from Model 1.
# User -> Model 1 -> Model 2
@st.cache_resource
def load_model_fill_mask():
    """Return a ``fill-mask`` pipeline backed by ``bert-base-uncased``.

    Cached with ``st.cache_resource`` so the pipeline is built once and reused
    across Streamlit reruns instead of being reloaded for every question.
    """
    return pipeline(task="fill-mask", model="bert-base-uncased")
| ############################################################################################################################# | |
# Model 3.
# Model 3 gets input from Model 2.
# User -> Model 1 -> Model 2 -> Model 3
@st.cache_resource
def load_model_text2text_generation():
    """Return a ``text2text-generation`` pipeline (facebook/blenderbot-400M-distill).

    Cached with ``st.cache_resource`` so the pipeline is built once and reused
    across Streamlit reruns instead of being reloaded for every question.
    """
    return pipeline(
        task="text2text-generation", model="facebook/blenderbot-400M-distill"
    )
| ############################################################################################################################# | |
# Model 4.
# Model 4 gets input from Model 3.
# User -> Model 1 -> Model 2 -> Model 3 -> Model 4
@st.cache_resource
def load_model_fill_text_generation():
    """Return a ``text-generation`` pipeline backed by ``gpt2``.

    Cached with ``st.cache_resource`` so the pipeline is built once and reused
    across Streamlit reruns instead of being reloaded for every question.
    """
    return pipeline(task="text-generation", model="gpt2")
| ############################################################################################################################# | |
# Model 5.
# Model 5 gets input from Model 4.
# User -> Model 1 -> Model 2 -> Model 3 -> Model 4 -> Model 5
@st.cache_resource
def load_model_bertopic1():
    """Return the ``davanstrien/chat_topics`` BERTopic model.

    Cached with ``st.cache_resource`` so the model is loaded once and reused
    across Streamlit reruns instead of being reloaded for every question.
    """
    return BERTopic.load(path="davanstrien/chat_topics")
@st.cache_resource
def load_model_bertopic2():
    """Return the ``MaartenGr/BERTopic_ArXiv`` BERTopic model.

    Cached with ``st.cache_resource`` so the model is loaded once and reused
    across Streamlit reruns instead of being reloaded for every question.
    """
    return BERTopic.load(path="MaartenGr/BERTopic_ArXiv")
| ############################################################################################################################# | |
# Page title and favicon.
st.set_page_config(page_title="Visual Question Answering", page_icon="❓")

# Pick the device for the Salesforce/blip model: CUDA when available, else CPU.
# (Apple-silicon "mps" support is left disabled.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the Salesforce/blip model directly and move it to the chosen device.
blip_processor_base, blip_model_base = load_model_blip()
blip_model_base.to(device)
| ############################################################################################################################# | |
# Main function to create the Streamlit web application.
#
# 5 MODEL INFERENCES.
# User Input = Image + Question About The Image.
# User -> Model 1 -> Model 2 -> Model 3 -> Model 4 -> Model 5
def _render_header():
    """Render the page title, subtitle and the centered CTP logo."""
    title = """<h1 align="center" style="font-family: monospace; font-size: 2.1rem; margin-top: -4rem">
    Georgios Ioannou's Visual Question Answering</h1>"""
    st.markdown(title, unsafe_allow_html=True)

    subtitle = """<h2 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: -2rem">
    CUNY Tech Prep Tutorial 4</h2>"""
    st.markdown(subtitle, unsafe_allow_html=True)

    # Center the logo by placing it in the middle of three equal-width columns.
    _, center_column, _ = st.columns(3)
    with center_column:
        st.image(image="./ctp.png")


def _topic_representation(topic_model, document):
    """Return the "Representation" of the topic ``topic_model`` assigns to ``document``.

    Args:
        topic_model: A fitted BERTopic model.
        document: A single text string.
    """
    # transform() expects a list of documents (strings) and returns
    # (topics, probabilities); we pass exactly one document.
    topics, _ = topic_model.transform([document])
    info = topic_model.get_topic_info(topics[0])
    # Positional access: independent of the filtered frame's index labels.
    return info["Representation"].iloc[0]


def _run_pipeline(raw_image, question):
    """Run the five chained models and display each intermediate output.

    Args:
        raw_image: The uploaded image, converted to an RGB PIL image.
        question: The user's (non-blank) question about ``raw_image``.
    """
    # Model 1 (VQA): User -> Model 1.
    with st.spinner(text="VQA inference..."):  # Spinner to keep the application interactive.
        answer = generate_answer(raw_image, question)[0]
    st.success(f"VQA: {answer}")

    # Model 2 (Fill-Mask): Model 1 -> Model 2.
    bbu_pipeline = load_model_fill_mask()
    text = "I love " + answer + " and I would like to know how to [MASK]."
    with st.spinner(text="Fill-Mask inference..."):
        # Keep the first (highest-scoring) completed sentence.
        bbu_output = bbu_pipeline(text)[0]["sequence"]
    st.success(f"Fill-Mask: {bbu_output}")

    # Model 3 (Text2text Generation): Model 2 -> Model 3.
    facebook_pipeline = load_model_text2text_generation()
    with st.spinner(text="Text2text Generation inference..."):
        facebook_output = facebook_pipeline(bbu_output)[0]["generated_text"]
    st.success(f"Text2text Generation: {facebook_output}")

    # Model 4 (Text Generation): Model 3 -> Model 4.
    gpt2_pipeline = load_model_fill_text_generation()
    with st.spinner(text="Fill Text Generation inference..."):
        gpt2_output = gpt2_pipeline(facebook_output)[0]["generated_text"]
    st.success(f"Fill Text Generation: {gpt2_output}")

    # Model 5 (Topic Modeling): Model 4 -> Model 5, using two BERTopic models.
    # Both receive the generated text itself, not the raw pipeline output
    # (a list of dicts), which BERTopic cannot embed.
    with st.spinner(text="Topic inference..."):
        topic_model_1_output = _topic_representation(load_model_bertopic1(), gpt2_output)
    st.success(f"Topic(s) from davanstrien/chat_topics: {topic_model_1_output}")

    with st.spinner(text="Topic inference..."):
        topic_model_2_output = _topic_representation(load_model_bertopic2(), gpt2_output)
    st.success(f"Topic(s) from MaartenGr/BERTopic_ArXiv: {topic_model_2_output}")


def _render_footer():
    """Render the centered link to the author's GitHub profile."""
    st.markdown(
        """
    <p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 1rem;"><b> Check out our
    <a href="https://github.com/GeorgiosIoannouCoder/" style="color: #FAF9F6;"> GitHub repository</a></b>
    </p>
    """,
        unsafe_allow_html=True,
    )


def main():
    """Build the page and run the five-model chain on the user's image + question.

    User -> Model 1 (VQA) -> Model 2 (Fill-Mask) -> Model 3 (Text2text
    Generation) -> Model 4 (Text Generation) -> Model 5 (Topic Modeling).
    Any exception raised while building the page or running inference is
    shown with ``st.error`` instead of crashing the app. The footer is
    always rendered.
    """
    import io  # Only needed to decode the uploaded image in memory.

    try:
        local_css("styles/style.css")
        _render_header()

        # User input (Image).
        uploaded_image = st.file_uploader(
            "Choose an image...", type=["jpg", "jpeg", "png"]
        )
        if uploaded_image is not None:
            st.image(uploaded_image, caption="Uploaded Image.", use_column_width=True)
            # Decode the upload in memory instead of saving it to disk under the
            # client-supplied file name. Saving left stray files in the working
            # directory and let concurrent sessions overwrite each other's image.
            raw_image = Image.open(io.BytesIO(uploaded_image.getvalue())).convert("RGB")

            # User input (Question).
            question = st.text_input("What's your question?")
            # Skip empty or whitespace-only questions rather than running all five models.
            if question.strip():
                _run_pipeline(raw_image, question)
    except Exception as e:
        # Top-level UI boundary: report any failure on the page.
        st.error(e)

    # GitHub repository of author.
    _render_footer()
| ############################################################################################################################# | |
# Script entry point: build the page when this file is executed directly
# (e.g. via `streamlit run`), but not when it is merely imported.
if __name__ == "__main__":
    main()