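# Dependencies this app needs at build time, inferred from the imports below.
# A minimal requirements.txt sketch (package names only; faiss-cpu is an
# assumption for a CPU Space, and versions are left unpinned):
#
#     gradio
#     torch
#     transformers
#     sentence-transformers
#     faiss-cpu
#     Pillow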
import gradio as gr
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from sentence_transformers import SentenceTransformer
import faiss  # For similarity search
from itertools import product

# Load BLIP model for image captioning
device = 'cuda' if torch.cuda.is_available() else 'cpu'
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

# Load Sentence Transformer for caption/query embeddings
sbert_model = SentenceTransformer('all-MiniLM-L6-v2')
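# Note: 'all-MiniLM-L6-v2' produces 384-dimensional sentence embeddings.
# The FAISS index below picks this dimension up at runtime, but it can be
# checked explicitly:
#     assert sbert_model.get_sentence_embedding_dimension() == 384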
def process_images(uploaded_files):
    if not uploaded_files:
        return None, None, None, None, "No images were uploaded."

    image_filenames = []
    image_captions = []
    image_data = {}  # Local to this function

    # Process each uploaded image
    for idx, img_file in enumerate(uploaded_files):
        # gr.File(type="filepath") yields plain path strings, so open directly
        image = Image.open(img_file).convert('RGB')

        # Prepare the image for BLIP
        inputs = blip_processor(images=image, return_tensors="pt").to(device)

        # Generate a caption with BLIP
        with torch.no_grad():
            out = blip_model.generate(**inputs)
        caption = blip_processor.decode(out[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

        # Store the image and its caption under a synthetic filename
        filename = f'Image-{idx}'
        image_filenames.append(filename)
        image_captions.append(caption)
        image_data[filename] = {
            'image': image,
            'caption': caption,
        }

    # Compute embeddings for the captions using the Sentence Transformer
    if image_captions:
        caption_embeddings = sbert_model.encode(image_captions, convert_to_tensor=True)
        caption_embeddings_np = caption_embeddings.cpu().numpy()

        # Initialize the FAISS vector store; with L2-normalized vectors,
        # inner product scores equal cosine similarity
        embedding_dim = caption_embeddings_np.shape[1]
        vector_store = faiss.IndexFlatIP(embedding_dim)
        faiss.normalize_L2(caption_embeddings_np)
        vector_store.add(caption_embeddings_np)

        # Return everything the recommendation step needs
        return image_data, vector_store, image_filenames, image_captions, "Images processed successfully!"
    else:
        return None, None, None, None, "No images were processed."
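# For reference, BLIP-base tends to produce short noun-phrase captions; a
# plain tee might be captioned something like "a white t shirt on a white
# background" (illustrative only, not an actual model output). The keyword
# matching in recommend_outfits below relies on this style of caption.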
def recommend_outfits(user_query, image_data, vector_store, image_filenames, image_captions):
    if vector_store is None or not image_data:
        return [], "Please upload images of your clothing first."

    # Encode the user query into the same embedding space as the captions
    query_embedding = sbert_model.encode([user_query], convert_to_tensor=True)
    query_embedding_np = query_embedding.cpu().numpy()
    faiss.normalize_L2(query_embedding_np)

    # Similarity search in the FAISS index; since all vectors are
    # L2-normalized, the inner-product scores are cosine similarities
    k = min(10, len(image_data))  # Retrieve top-k results
    distances, indices = vector_store.search(query_embedding_np, k)

    # Map retrieved indices back to filenames and captions
    retrieved_filenames = [image_filenames[idx] for idx in indices[0]]
    retrieved_captions = [image_captions[idx] for idx in indices[0]]
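    # For illustration only: with k = 3, `indices` might look like
    # [[4, 0, 7]] and `distances` like [[0.62, 0.41, 0.33]] (cosine
    # scores, higher is more similar); these values are hypothetical.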
    # Categorize retrieved items by simple keyword matching on captions
    # (note: plain substring checks, so e.g. 'dress' also matches 'dressed')
    categories = {
        'tops': ['shirt', 't-shirt', 'jacket', 'sweater', 'blouse', 'coat'],
        'bottoms': ['jeans', 'pants', 'shorts', 'skirt', 'trousers', 'chino'],
        'dresses': ['dress', 'gown'],
        'footwear': ['shoes', 'sneakers', 'boots', 'heels'],
        'accessories': ['hat', 'sunglasses', 'scarf', 'belt', 'bag'],
    }
    items_by_category = {cat: [] for cat in categories}
    for filename, caption in zip(retrieved_filenames, retrieved_captions):
        matched = False
        for category, keywords in categories.items():
            if any(keyword in caption.lower() for keyword in keywords):
                items_by_category[category].append((filename, caption))
                matched = True
                break
        if not matched:
            items_by_category.setdefault('others', []).append((filename, caption))

    # Generate outfit combinations as Cartesian products of categories
    if items_by_category['dresses']:
        # Outfits with dresses and footwear
        combinations = list(product(items_by_category['dresses'], items_by_category['footwear']))
    else:
        # Tops and bottoms
        combinations = list(product(items_by_category['tops'], items_by_category['bottoms']))
        # Include footwear as a third slot if available
        if items_by_category['footwear']:
            combinations = list(product(items_by_category['tops'], items_by_category['bottoms'], items_by_category['footwear']))
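    # For example, product() over two tops, one bottom, and one pair of
    # shoes yields two 3-item outfits:
    #     list(product(['t1', 't2'], ['b1'], ['s1']))
    #     -> [('t1', 'b1', 's1'), ('t2', 'b1', 's1')]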
    if not combinations:
        return [], "Not enough items to generate outfit combinations based on the current categories and retrieved items."

    # Collect images and captions for the top outfit combinations
    outputs = []
    for outfit in combinations[:3]:  # Limit to top 3 recommendations
        images = []
        captions = []
        for filename, caption in outfit:
            images.append(image_data[filename]['image'])
            captions.append(caption)
        outputs.append((images, captions))

    # Prepare the outputs for the Gradio components
    output_images = []
    output_texts = []
    for images, captions in outputs:
        # Stitch each outfit's images into a single horizontal strip
        widths, heights = zip(*(img.size for img in images))
        total_width = sum(widths)
        max_height = max(heights)
        new_im = Image.new('RGB', (total_width, max_height))
        x_offset = 0
        for img in images:
            new_im.paste(img, (x_offset, 0))  # Top-aligned paste
            x_offset += img.size[0]
        output_images.append(new_im)
        output_texts.append('\n'.join(captions))

    return output_images, '\n\n'.join(output_texts)
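# A minimal local smoke test (a sketch; "top.jpg" and "jeans.jpg" are
# hypothetical image paths, not files shipped with this Space):
#
#     data, store, names, caps, msg = process_images(["top.jpg", "jeans.jpg"])
#     gallery, text = recommend_outfits("casual weekend look", data, store, names, caps)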
# Gradio interface setup using Blocks
def gradio_app():
    with gr.Blocks() as demo:
        gr.Markdown("# RAG-Based Outfit Recommendation System")
        gr.Markdown("Ever spend too much time in the morning, or before going out, trying to decide what to wear?")
        gr.Markdown("Think no more: just upload images of your clothing items to build a virtual wardrobe, then describe in natural language what kind of outfit you need.")
        gr.Markdown("Example: I want something classy but with bright colors for a date night.")
        gr.Markdown("The system works like a typical RAG pipeline: your clothing items are embedded in a vector space based on captions generated by a vision-language model (VLM).")
        gr.Markdown("Your query is then matched against the best combinations of items from your wardrobe.")

        image_data_state = gr.State()
        vector_store_state = gr.State()
        image_filenames_state = gr.State()
        image_captions_state = gr.State()

        with gr.Tab("Step 1: Upload Images"):
            with gr.Row():
                image_input = gr.File(type="filepath", label="Upload Your Clothing Images", file_count="multiple")
            process_button = gr.Button("Process Images")
            output_message = gr.Textbox(label="Status")
            process_button.click(
                fn=process_images,
                inputs=image_input,
                outputs=[image_data_state, vector_store_state, image_filenames_state, image_captions_state, output_message]
            )

        with gr.Tab("Step 2: Get Recommendations"):
            user_query = gr.Textbox(lines=2, placeholder="Enter your outfit preference (e.g., casual for a date)", label="Describe Your Outfit Request")
            recommend_button = gr.Button("Recommend Outfits")
            output_gallery = gr.Gallery(label="Recommended Outfit Combinations")
            output_descriptions = gr.Textbox(label="Item Descriptions")
            recommend_button.click(
                fn=recommend_outfits,
                inputs=[user_query, image_data_state, vector_store_state, image_filenames_state, image_captions_state],
                outputs=[output_gallery, output_descriptions]
            )
    return demo

if __name__ == "__main__":
    demo = gradio_app()
    demo.launch(share=True)
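# Note: share=True only matters when running locally; on Hugging Face Spaces
# the app is already publicly hosted, so a plain demo.launch() is sufficient
# there (Gradio does not create a share link in that environment).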