Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import os | |
| import zipfile | |
| import tempfile | |
| import base64 | |
| from PIL import Image | |
| from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer | |
| import pandas as pd | |
| from nltk.corpus import wordnet | |
| import spacy | |
| import io | |
| from spacy.cli import download | |
# Model / corpus setup.
#
# Streamlit re-executes this whole script on every widget interaction, so all
# expensive one-time work (package downloads, model loading) is wrapped in
# st.cache_resource. It then runs once per server process instead of on every
# rerun. The module-level names nlp, model, feature_extractor and tokenizer
# are still defined exactly as before.
import nltk


@st.cache_resource
def _load_spacy_model(name="en_core_web_sm"):
    """Return the spaCy pipeline *name*, downloading it only if it is missing."""
    try:
        return spacy.load(name)
    except OSError:
        # spacy.load raises OSError when the model package is not installed;
        # only then pay for the (pip-based) download.
        download(name)
        return spacy.load(name)


@st.cache_resource
def _ensure_nltk_data():
    """Download the WordNet corpora used by get_synonyms (once per process)."""
    for resource in ("wordnet", "omw-1.4"):
        nltk.download(resource, quiet=True)
    return True


@st.cache_resource
def _load_captioning_model(name):
    """Load the VisionEncoderDecoder captioner plus its image processor and tokenizer.

    Returns:
        (model, feature_extractor, tokenizer), with the tokenizer's pad token
        set to its EOS token and the model config's eos/decoder-start/pad ids
        synced to the tokenizer.
    """
    caption_model = VisionEncoderDecoderModel.from_pretrained(name)
    image_processor = ViTImageProcessor.from_pretrained(name)
    caption_tokenizer = AutoTokenizer.from_pretrained(name)
    # Presumably the GPT-2 style decoder tokenizer ships without a pad token,
    # so EOS is reused for padding (as the original setup did).
    caption_tokenizer.pad_token = caption_tokenizer.eos_token
    # Update the model config so generation starts/stops/pads with the
    # tokenizer's special tokens.
    caption_model.config.eos_token_id = caption_tokenizer.eos_token_id
    caption_model.config.decoder_start_token_id = caption_tokenizer.bos_token_id
    caption_model.config.pad_token_id = caption_tokenizer.pad_token_id
    return caption_model, image_processor, caption_tokenizer


# Same order as before: spaCy, then NLTK data, then the captioning model.
nlp = _load_spacy_model()
_ensure_nltk_data()
# Load the pre-trained model for image captioning
model_name = "NourFakih/Vit-GPT2-COCO2017Flickr-85k-11"
model, feature_extractor, tokenizer = _load_captioning_model(model_name)
def generate_caption(image, **generate_kwargs):
    """Generate a caption for a PIL image with the module-level captioning model.

    Args:
        image: a PIL.Image. Images that are not already RGB (e.g. RGBA or
            palette PNGs, greyscale JPEGs) are converted to RGB first, because
            the ViT image processor normalises three colour channels.
        **generate_kwargs: optional keyword arguments forwarded unchanged to
            ``model.generate`` (e.g. ``max_length``, ``num_beams``).

    Returns:
        The decoded text of the first generated sequence, with special tokens
        removed and surrounding whitespace stripped.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    output_ids = model.generate(pixel_values, **generate_kwargs)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
def get_synonyms(word):
    """Return the set of WordNet lemma names sharing any synset with *word*.

    Lemma names are returned exactly as WordNet spells them (multi-word
    lemmas keep their underscores, e.g. ``"domestic_dog"``).
    """
    return {
        lemma.name()
        for synset in wordnet.synsets(word)
        for lemma in synset.lemmas()
    }
def preprocess_query(query):
    """Expand a piece of text into a set of lowercase tokens for fuzzy matching.

    Each kept spaCy token contributes:
    - its lowercased text,
    - its lowercased lemma,
    - the WordNet synonyms of its lowercased text.

    Stop words, punctuation and whitespace tokens are skipped. Words such as
    "a", "the", "on" and "is" occur in almost every caption, and some of them
    have many unrelated WordNet synonyms. Keeping them would make nearly every
    query match every image.

    If the text consists only of such tokens, its non-whitespace tokens are
    kept instead, so the result is not empty.

    The same function is applied to both queries and captions, so both sides
    are normalised identically.
    """
    doc = nlp(query)
    content = [t for t in doc if not (t.is_stop or t.is_punct or t.is_space)]
    if not content:
        content = [t for t in doc if not t.is_space]
    tokens = set()
    for token in content:
        text = token.text.lower()
        tokens.add(text)
        tokens.add(token.lemma_.lower())
        tokens.update(get_synonyms(text))
    return tokens
def search_captions(query, captions):
    """Return the (path, caption) pairs whose caption overlaps *query*.

    Both the query and each caption are expanded with preprocess_query; a
    caption matches when the two token sets share at least one element.
    Results keep the iteration order of *captions*.
    """
    wanted = preprocess_query(query)
    return [
        (path, caption)
        for path, caption in captions.items()
        if not wanted.isdisjoint(preprocess_query(caption))
    ]
st.title("Image Captioning Gallery")

# Sidebar for search functionality
with st.sidebar:
    query = st.text_input("Search images by caption:")

# Options for input strategy
input_option = st.selectbox("Select input method:", ["Folder Path", "Upload Images", "Upload ZIP"])

# Dotted suffixes so names like "notapng" are not mistaken for images.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

# One scratch directory per browser session. It is created once and reused
# across Streamlit reruns, so uploads are overwritten in place instead of
# leaking a new temp file per rerun, and sessions never share files.
if "_upload_dir" not in st.session_state:
    st.session_state["_upload_dir"] = tempfile.mkdtemp(prefix="caption_gallery_")
upload_dir = st.session_state["_upload_dir"]

image_files = []
if input_option == "Folder Path":
    # NOTE(review): this reads any directory on the server the app can access;
    # presumably intended for local use only.
    folder_path = st.text_input("Enter the folder path containing images:")
    if folder_path and os.path.isdir(folder_path):
        image_files = [
            os.path.join(folder_path, f)
            for f in sorted(os.listdir(folder_path))
            if f.lower().endswith(IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(folder_path, f))
        ]
elif input_option == "Upload Images":
    uploaded_files = st.file_uploader("Upload image files", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
    for i, uploaded_file in enumerate(uploaded_files or []):
        # The index prefix keeps same-named uploads apart; the stable name
        # makes reruns overwrite rather than accumulate files.
        target = os.path.join(upload_dir, f"{i}_{os.path.basename(uploaded_file.name)}")
        with open(target, "wb") as out:
            out.write(uploaded_file.getvalue())
        image_files.append(target)
elif input_option == "Upload ZIP":
    uploaded_zip = st.file_uploader("Upload a ZIP file containing images", type=["zip"])
    if uploaded_zip:
        extract_dir = os.path.join(upload_dir, "zip")
        try:
            # The upload is file-like, so it is read directly (no temp copy).
            with zipfile.ZipFile(uploaded_zip) as zip_ref:
                for info in zip_ref.infolist():
                    member = info.filename
                    base = os.path.basename(member)
                    # Skip folders and macOS resource-fork junk ("__MACOSX/",
                    # "._name.jpg"), which are not decodable images.
                    if info.is_dir() or member.startswith("__MACOSX/") or base.startswith("._"):
                        continue
                    if base.lower().endswith(IMAGE_EXTENSIONS):
                        # extract() returns the real (sanitised) output path.
                        image_files.append(zip_ref.extract(info, extract_dir))
        except zipfile.BadZipFile as e:
            st.error(f"Invalid ZIP file: {e}")
# Captions live in session state. A plain `captions = {}` is reset on every
# Streamlit rerun, e.g. as soon as the user types a search query, which made
# the gallery vanish and every search return nothing.
if "captions" not in st.session_state:
    st.session_state["captions"] = {}

if st.button("Generate Captions", key='generate_captions'):
    new_captions = {}
    with st.spinner("Generating captions..."):
        for image_file in image_files:
            try:
                # The context manager closes the underlying file handle.
                with Image.open(image_file) as image:
                    new_captions[image_file] = generate_caption(image)
            except Exception as e:
                # UI boundary: report the failing image and keep going.
                st.error(f"Error processing {image_file}: {e}")
    # Each button press replaces the previous result set, as before.
    st.session_state["captions"] = new_captions

captions = st.session_state["captions"]
import html
import mimetypes


def _render_image_grid(items, error_label, num_columns=4):
    """Render (image_path, caption) pairs as an HTML image grid.

    Items are laid out round-robin across ``num_columns`` Streamlit columns.

    Because the markup goes through ``unsafe_allow_html``, the caption and
    the path (which comes from user-chosen folders or ZIP entry names) are
    HTML-escaped. The data-URI MIME type is guessed from the file name.

    The markup is built without leading indentation; indented lines could
    otherwise be rendered by Markdown as a code block.

    Args:
        items: iterable of (image_path, caption) pairs.
        error_label: prefix of the error shown when an image cannot be read.
        num_columns: number of grid columns (default 4).
    """
    cols = st.columns(num_columns)
    for idx, (image_path, caption) in enumerate(items):
        with cols[idx % num_columns]:
            try:
                with open(image_path, "rb") as img_file:
                    encoded_image = base64.b64encode(img_file.read()).decode()
            except OSError as e:
                st.error(f"{error_label} {image_path}: {e}")
                continue
            mime = mimetypes.guess_type(image_path)[0] or "image/jpeg"
            markup = (
                "<div style='text-align: center;'>"
                f"<img src='data:{mime};base64,{encoded_image}' width='100%'>"
                f"<p>{html.escape(caption)}</p>"
                "<p style='font-size: small; font-style: italic;'>"
                f"{html.escape(image_path)}</p>"
                "</div>"
            )
            st.markdown(markup, unsafe_allow_html=True)


# Display images in a grid
st.subheader("Images and Captions:")
_render_image_grid(captions.items(), "Error displaying")

if query:
    results = search_captions(query, captions)
    st.write("Search Results:")
    if not results:
        st.info("No images match your search.")
    _render_image_grid(results, "Error displaying search result")
# Export the captions as a two-column spreadsheet (Image, Caption). The
# workbook is built in an in-memory buffer and handed straight to the
# download button, so nothing is written to disk.
df = pd.DataFrame(list(captions.items()), columns=['Image', 'Caption'])

excel_file = io.BytesIO()
df.to_excel(excel_file, index=False)
excel_file.seek(0)  # rewind so the download starts at the first byte

st.download_button(
    label="Download captions as Excel",
    data=excel_file,
    file_name="captions.xlsx",
    mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
)