| import streamlit as st |
| import open_clip |
| import torch |
| from PIL import Image |
| import numpy as np |
| from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation |
| import chromadb |
| import logging |
| import io |
| import requests |
| from concurrent.futures import ThreadPoolExecutor |
| from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction |
| from chromadb.utils.data_loaders import ImageLoader |
|
|
| |
# Module-wide logger; the root logger is configured to emit INFO and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
|
|
class CustomFashionEmbeddingFunction:
    """Embedding function blending open_clip image features with colour.

    Each input image is encoded with the ``hf-hub:Marqo/marqo-fashionSigLIP``
    open_clip checkpoint, the feature vector is padded or truncated to
    ``EMBEDDING_DIM`` values, and it is then mixed with a tiled HSV colour
    histogram of the same length. The result is L2-normalised.

    NOTE: the constants below define the vector layout. Changing any of them
    makes new embeddings incomparable with ones computed earlier.
    """

    # Length every returned embedding is forced to.
    EMBEDDING_DIM = 768
    # Histogram bins per HSV channel (3 channels -> 24 values in total).
    HIST_BINS = 8
    # How often each histogram value is repeated to reach EMBEDDING_DIM.
    COLOR_REPEAT = EMBEDDING_DIM // (3 * HIST_BINS)
    # Weights of the two normalised components in the blend.
    CLIP_WEIGHT = 0.7
    COLOR_WEIGHT = 0.3
    # Seconds to wait for an image given as a URL before giving up.
    REQUEST_TIMEOUT = 10

    def __init__(self):
        """Load the model and preprocessing transform onto CUDA if available."""
        self.model, _, self.preprocess = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        # open_clip hands models back in training mode; inference needs eval
        # mode so that train-only layers behave deterministically.
        self.model.eval()

    def _load_image(self, img):
        """Return *img* as an RGB PIL image.

        Accepts a URL string, raw encoded bytes, a numpy array or a PIL
        image. Any other object is returned unchanged.
        """
        if isinstance(img, str):
            response = requests.get(img, timeout=self.REQUEST_TIMEOUT)
            # Fail on HTTP errors instead of trying to decode an error page.
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content)).convert('RGB')
        if isinstance(img, bytes):
            return Image.open(io.BytesIO(img)).convert('RGB')
        if isinstance(img, np.ndarray):
            return Image.fromarray(img.astype('uint8')).convert('RGB')
        if isinstance(img, Image.Image):
            return img.convert('RGB')
        return img

    def _combine(self, clip_emb, color_feat):
        """Blend one image feature vector with one colour histogram."""
        dim = self.EMBEDDING_DIM
        if clip_emb.shape[0] < dim:
            # Zero-pad short feature vectors up to the fixed length.
            clip_emb = np.concatenate([clip_emb, np.zeros(dim - clip_emb.shape[0])])
        else:
            clip_emb = clip_emb[:dim]

        # Tile the histogram so both components have the same length.
        color_vec = np.repeat(color_feat, self.COLOR_REPEAT)

        # Normalise each part first so the weights control the mix, not the
        # raw magnitudes; the epsilon guards against all-zero vectors.
        clip_emb = clip_emb / (np.linalg.norm(clip_emb) + 1e-8)
        color_vec = color_vec / (np.linalg.norm(color_vec) + 1e-8)

        combined = clip_emb * self.CLIP_WEIGHT + color_vec * self.COLOR_WEIGHT
        return combined / (np.linalg.norm(combined) + 1e-8)

    # The parameter is called ``input`` (shadowing the builtin) on purpose:
    # renaming it would change the call signature seen by callers.
    def __call__(self, input):
        """Embed a batch of images.

        Args:
            input: iterable of images; each may be a URL string, encoded
                bytes, a numpy array or a PIL image.

        Returns:
            A numpy array with one ``EMBEDDING_DIM``-long row per image. An
            empty input yields an array of shape ``(0, EMBEDDING_DIM)``.

        Raises:
            Exception: any loading or model error is logged and re-raised.
        """
        try:
            # Decode every input exactly once; the same image feeds both the
            # model and the colour histogram.
            images = [self._load_image(img) for img in input]
            if not images:
                return np.empty((0, self.EMBEDDING_DIM))

            batch = torch.stack([self.preprocess(img) for img in images]).to(self.device)

            with torch.no_grad():
                clip_features = self.model.encode_image(batch)
            clip_features = clip_features.cpu().numpy()

            combined_embeddings = [
                self._combine(clip_emb, self.extract_color_histogram(img))
                for clip_emb, img in zip(clip_features, images)
            ]
            return np.array(combined_embeddings)

        except Exception as e:
            logger.error(f"Error in embedding function: {e}")
            raise

    def extract_color_histogram(self, image):
        """Extract a normalised HSV colour histogram from the image.

        Args:
            image: URL string, encoded bytes, numpy array or PIL image.

        Returns:
            A 1-D array of ``3 * HIST_BINS`` values: the H, S and V
            histograms, each scaled to sum to (approximately) one. On any
            error the problem is logged and an all-zero array is returned.
        """
        try:
            # Go through RGB first so the HSV conversion always starts from
            # a known mode, whatever mode the input had.
            hsv_pixels = np.array(self._load_image(image).convert('HSV'))

            histograms = []
            for channel in range(3):
                counts = np.histogram(
                    hsv_pixels[:, :, channel],
                    bins=self.HIST_BINS,
                    range=(0, 256),
                )[0]
                histograms.append(counts / (counts.sum() + 1e-8))

            return np.concatenate(histograms)
        except Exception as e:
            logger.error(f"Color histogram extraction error: {e}")
            return np.zeros(3 * self.HIST_BINS)
|
|
| |
# Seed the Streamlit session state with defaults on the first run of a
# session; keys that already exist are left untouched so reruns keep their
# values.
_SESSION_DEFAULTS = (
    ('image', None),
    ('detected_items', None),
    ('selected_item_index', None),
    ('upload_state', 'initial'),
    ('search_clicked', False),
)
for _state_key, _state_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
|
|
| |
@st.cache_resource
def load_segmentation_model():
    """Load the ``mattmdjaga/segformer_b2_clothes`` model and image processor.

    The model is moved to CUDA when a GPU is available. The result is cached
    through ``st.cache_resource``.

    Returns:
        A ``(model, image_processor)`` tuple.

    Raises:
        Exception: any loading error is logged and re-raised.
    """
    checkpoint = "mattmdjaga/segformer_b2_clothes"
    try:
        processor = AutoImageProcessor.from_pretrained(checkpoint)
        segmenter = AutoModelForSemanticSegmentation.from_pretrained(checkpoint)

        if torch.cuda.is_available():
            segmenter = segmenter.to('cuda')
    except Exception as e:
        logger.error(f"Error loading segmentation model: {e}")
        raise
    return segmenter, processor
|
|
| |
def setup_multimodal_collection():
    """Set up the multimodal collection, creating it if it does not exist.

    The collection lives in the persistent Chroma database at
    ``./fashion_multimodal_db`` and uses ``CustomFashionEmbeddingFunction``
    together with Chroma's ``ImageLoader``.

    Returns:
        The ``fashion_multimodal_v2`` collection.

    Raises:
        Exception: any client or collection error is logged and re-raised.
    """
    # One name for both lookup and creation. Previously a failed lookup of
    # "fashion_multimodal_v2" created a differently named collection, so the
    # next call failed the lookup again and then hit "already exists".
    collection_name = "fashion_multimodal_v2"
    try:
        client = chromadb.PersistentClient(path="./fashion_multimodal_db")
        collection = client.get_or_create_collection(
            name=collection_name,
            embedding_function=CustomFashionEmbeddingFunction(),
            data_loader=ImageLoader(),
        )
        logger.info(f"Connected to {collection_name} collection")
        return collection

    except Exception as e:
        logger.error(f"Error setting up multimodal collection: {e}")
        raise
|
|
def process_segmentation(image):
    """Segment the image and return one entry per detected label.

    Args:
        image: the picture to segment; its ``size`` attribute is reversed to
            get the upsampling target (assumes a PIL image whose ``size`` is
            ``(width, height)`` - TODO confirm).

    Returns:
        A list of dicts with keys ``'label'`` (``"Item_<label index>"``),
        ``'score'`` (always ``1.0``) and ``'mask'`` (float array that is 1
        where the pixel carries that label). Label 0 is skipped (presumably
        the background class). On any error the traceback is logged and an
        empty list is returned.
    """
    try:
        segmenter, processor = load_segmentation_model()

        model_inputs = processor(image, return_tensors="pt")
        if torch.cuda.is_available():
            model_inputs = {name: tensor.to('cuda') for name, tensor in model_inputs.items()}

        with torch.no_grad():
            prediction = segmenter(**model_inputs)

        # Bring the logits back to the input resolution before picking the
        # most likely label per pixel.
        full_res_logits = torch.nn.functional.interpolate(
            prediction.logits.cpu(),
            size=image.size[::-1],
            mode="bilinear",
            align_corners=False,
        )
        label_map = full_res_logits.argmax(dim=1).numpy()

        segments = [
            {
                'label': f"Item_{label}",
                'score': 1.0,
                'mask': (label_map[0] == label).astype(float),
            }
            for label in np.unique(label_map)
            if label != 0
        ]

        logger.info(f"Successfully processed {len(segments)} segments")
        return segments

    except Exception as e:
        logger.error(f"Segmentation error: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        return []
| |
def _distance_to_similarity(distance):
    """Map a query distance to the similarity percentage shown to the user."""
    # NOTE(review): this treats `distance` as the Euclidean distance between
    # unit vectors (cos = 1 - d^2 / 2). Confirm it matches the distance
    # space the collections were created with.
    cosine_similarity = 1 - (distance ** 2 / 2)
    return ((cosine_similarity + 1) / 2) * 100


def _rank_unique_results(items, limit):
    """Return up to *limit* items, best score first, one per image URL."""
    seen_urls = set()
    unique_results = []
    for item in sorted(items, key=lambda entry: entry['similarity_score'], reverse=True):
        if len(unique_results) >= limit:
            break
        url = item.get('image_url')
        # Only de-duplicate on a real URL: entries without one are distinct
        # products, not copies of each other.
        if url:
            if url in seen_urls:
                continue
            seen_urls.add(url)
        unique_results.append(item)
    return unique_results


def search_similar_items(image, mask=None, top_k=10):
    """Search both multimodal collections for items similar to the image.

    Args:
        image: query picture (anything ``np.array`` can convert).
        mask: optional 2-D array multiplied into every colour channel so
            that only the selected region is searched.
        top_k: number of hits requested from each collection and the
            maximum number of results returned.

    Returns:
        A list of metadata dicts, each extended with a
        ``'similarity_score'`` entry, sorted best first and de-duplicated by
        ``'image_url'``. An empty list is returned when no collection is
        available or when the search fails.
    """
    try:
        client = chromadb.PersistentClient(path="./fashion_multimodal_db")
        # NOTE(review): this loads the embedding model on every search;
        # consider sharing one cached instance.
        embedding_function = CustomFashionEmbeddingFunction()
        data_loader = ImageLoader()

        collections = []
        for name in ("fashion_multimodal", "fashion_multimodal_v2"):
            try:
                collection = client.get_collection(
                    name=name,
                    embedding_function=embedding_function,
                    data_loader=data_loader
                )
                collections.append(collection)
                logger.info(f"Successfully connected to {name} collection")
            except Exception as e:
                logger.error(f"Error getting collection {name}: {e}")
                continue

        if not collections:
            logger.error("No collections available for search")
            return []

        if mask is not None:
            # Broadcast the 2-D mask over the colour channels and black out
            # everything outside the selected item.
            mask_3d = np.stack([mask] * 3, axis=-1)
            query_array = (np.array(image) * mask_3d).astype(np.uint8)
        else:
            query_array = np.array(image)

        all_results = []
        for collection in collections:
            # Best effort: a failure in one collection must not discard the
            # hits from the other.
            try:
                results = collection.query(
                    query_images=[query_array],
                    n_results=top_k,
                    include=['metadatas', 'distances']
                )
                if not results:
                    continue

                metadatas = (results.get('metadatas') or [[]])[0]
                distances = (results.get('distances') or [[]])[0]
                for metadata, distance in zip(metadatas, distances):
                    # Entries stored without metadata come back as None.
                    item_data = dict(metadata or {})
                    item_data['similarity_score'] = _distance_to_similarity(distance)
                    all_results.append(item_data)

            except Exception as e:
                logger.error(f"Error searching in collection: {e}")
                continue

        return _rank_unique_results(all_results, top_k)

    except Exception as e:
        logger.error(f"Multimodal search error: {e}")
        return []
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def _format_price(amount):
    """Format a price: numbers get thousands separators and the won suffix."""
    if isinstance(amount, (int, float)):
        # "원" is the Korean won; the suffix had been reduced to a stray
        # combining character by an encoding mix-up.
        return f"{amount:,}원"
    return f"{amount}"


def _render_similar_item(item):
    """Draw one search hit (image, score, brand, name, prices)."""
    try:
        if item.get('image_url'):
            # NOTE(review): newer Streamlit releases prefer
            # `use_container_width`; kept as-is for compatibility.
            st.image(item['image_url'], use_column_width=True)

        score = item.get('similarity_score')
        if score is not None:
            st.markdown(f"**Similarity: {float(score):.1f}%**")

        st.write(f"Brand: {item.get('brand', 'Unknown')}")
        # Metadata values are not guaranteed to be strings.
        name = str(item.get('name') or 'Unknown')
        if len(name) > 50:
            name = name[:47] + "..."
        st.write(f"Name: {name}")

        st.write(f"Price: {_format_price(item.get('price', 0))}")

        if item.get('discount'):
            st.write(f"Discount: {item['discount']}%")
            if 'original_price' in item:
                st.write(f"Original: {_format_price(item['original_price'])}")

        st.divider()

    except Exception as e:
        logger.error(f"Error displaying item: {e}")
        st.error("Error displaying this item")


def show_similar_items(similar_items):
    """Display similar items in a two-column grid with similarity scores.

    Args:
        similar_items: list of result dicts. Recognised keys are
            ``image_url``, ``similarity_score``, ``brand``, ``name``,
            ``price``, ``discount`` and ``original_price``; all are
            optional. An empty or missing list shows a warning instead.
    """
    if not similar_items:
        st.warning("No similar items found.")
        return

    st.subheader("Similar Items:")

    items_per_row = 2
    for row_start in range(0, len(similar_items), items_per_row):
        row_items = similar_items[row_start:row_start + items_per_row]
        cols = st.columns(items_per_row)
        for col, item in zip(cols, row_items):
            with col:
                _render_similar_item(item)
|
|
def process_search(image, mask, num_results):
    """Run the similar-item search while showing a progress spinner.

    Args:
        image: the uploaded picture.
        mask: mask of the selected item, forwarded to the search.
        num_results: how many results to ask for.

    Returns:
        Whatever ``search_similar_items`` returns, or ``None`` when an
        exception is raised (the error is logged).
    """
    try:
        with st.spinner("Finding similar items..."):
            return search_similar_items(image, mask, num_results)
    except Exception as e:
        logger.error(f"Search processing error: {e}")
        return None
|
|
def handle_file_upload():
    """Store the freshly uploaded picture in the session state.

    Meant to be used as the ``on_change`` callback of the uploader whose
    widget key is ``uploaded_file``. On success ``image`` holds the RGB
    picture and ``upload_state`` becomes ``'image_uploaded'``. A file that
    cannot be decoded is reported to the user and leaves the state as it was.
    """
    uploaded = st.session_state.get('uploaded_file')
    if uploaded is None:
        return

    try:
        image = Image.open(uploaded).convert('RGB')
    except (OSError, ValueError) as e:
        logger.error(f"Error opening uploaded image: {e}")
        st.error("Could not read the uploaded file as an image.")
        return

    st.session_state.image = image
    st.session_state.upload_state = 'image_uploaded'
    # No st.rerun() here: Streamlit reruns the script after a widget
    # callback anyway, and calling st.rerun() inside one is a no-op.
|
|
def handle_detection():
    """Segment the uploaded picture and store the result.

    Meant to be used as a button ``on_click`` callback. Does nothing when no
    image has been uploaded. Otherwise ``detected_items`` receives the list
    returned by ``process_segmentation`` and ``upload_state`` becomes
    ``'items_detected'``.
    """
    if st.session_state.image is None:
        return

    st.session_state.detected_items = process_segmentation(st.session_state.image)
    st.session_state.upload_state = 'items_detected'
    # No st.rerun() here: Streamlit reruns the script after a widget
    # callback anyway, and calling st.rerun() inside one is a no-op.
|
|
def handle_search():
    """Record in the session state that a search has been requested."""
    st.session_state['search_clicked'] = True
|
|
def _render_admin_sidebar():
    """Draw the admin sidebar and run a database update when requested."""
    st.sidebar.title("Admin Controls")
    if not st.sidebar.checkbox("Show Admin Interface"):
        return

    if st.sidebar.button("Update Database (Multimodal)"):
        # The updater is looked up at click time so that a missing
        # definition is reported in the UI instead of raising NameError.
        updater = globals().get('update_db_with_multimodal')
        if updater is None:
            logger.error("update_db_with_multimodal is not defined")
            st.sidebar.error("Failed to update database")
        else:
            with st.spinner("Updating database with multimodal support..."):
                success = updater()
            if success:
                st.sidebar.success("Database updated successfully!")
            else:
                st.sidebar.error("Failed to update database")
    st.divider()


def _render_detected_items(image, detected_items):
    """Show every detected segment as the picture with the rest blacked out."""
    image_array = np.array(image)
    cols = st.columns(2)
    for idx, item in enumerate(detected_items):
        with cols[idx % 2]:
            try:
                if item.get('mask') is not None:
                    # Add a channel axis so the 2-D mask broadcasts over RGB.
                    masked_img = image_array * np.expand_dims(item['mask'], axis=2)
                    st.image(masked_img.astype(np.uint8), caption=f"Detected {item.get('label', 'Unknown')}")

                st.write(f"Item {idx + 1}: {item.get('label', 'Unknown')}")
                score = item.get('score')
                if score is not None and isinstance(score, (int, float)):
                    st.write(f"Confidence: {score*100:.1f}%")
                else:
                    st.write("Confidence: N/A")
            except Exception as e:
                logger.error(f"Error displaying item {idx}: {str(e)}")
                st.error(f"Error displaying item {idx}")


def _render_search(image, detected_items):
    """Draw the item selector and search controls, then show the results."""
    valid_items = [idx for idx, item in enumerate(detected_items)
                   if item.get('mask') is not None]
    if not valid_items:
        st.warning("No valid items detected for search.")
        return

    selected_idx = st.selectbox(
        "Select item to search:",
        valid_items,
        format_func=lambda i: f"{detected_items[i].get('label', 'Unknown')}",
        key='item_selector'
    )

    search_col1, search_col2 = st.columns([1, 2])
    with search_col1:
        search_clicked = st.button("Search Similar Items",
                                   key='search_button',
                                   type="primary")
    with search_col2:
        num_results = st.slider("Number of results:",
                                min_value=1,
                                max_value=20,
                                value=5,
                                key='num_results')

    # `search_clicked` in the session state keeps the results on screen
    # across the reruns triggered by other widgets.
    if not (search_clicked or st.session_state.get('search_clicked', False)):
        return
    st.session_state.search_clicked = True

    # Cached results are only valid for the item/count they were computed
    # for; search again on an explicit click or when either one changed.
    query = (selected_idx, num_results)
    if (search_clicked
            or 'search_results' not in st.session_state
            or st.session_state.get('search_query') != query):
        st.session_state.search_results = process_search(
            image, detected_items[selected_idx]['mask'], num_results)
        st.session_state.search_query = query

    if st.session_state.search_results:
        show_similar_items(st.session_state.search_results)
    else:
        st.warning("No similar items found.")


def main():
    """Render the app: upload, item detection, item selection and search.

    The flow is driven by ``st.session_state``: ``upload_state`` decides
    whether the uploader is shown, ``image`` and ``detected_items`` hold the
    intermediate results, and ``search_results`` caches the last search.
    """
    st.title("Fashion Search App")

    _render_admin_sidebar()

    if st.session_state.upload_state == 'initial':
        st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'],
                         key='uploaded_file', on_change=handle_file_upload)

    image = st.session_state.image
    if image is None:
        return

    st.image(image, caption="Uploaded Image", use_column_width=True)

    detected_items = st.session_state.detected_items
    if detected_items is None:
        st.button("Detect Items", key='detect_button', on_click=handle_detection)
    elif detected_items:
        _render_detected_items(image, detected_items)
        _render_search(image, detected_items)
    else:
        # Detection ran but found nothing; say so instead of showing a
        # page with no way forward.
        st.warning("No valid items detected for search.")

    # Always reachable once an image is loaded, so the user can start over
    # from any state.
    if st.button("Start New Search", key='new_search'):
        for key in list(st.session_state.keys()):
            del st.session_state[key]
        st.rerun()
|
|
# Script entry point: announce start-up on stdout, then render the app.
if __name__ == "__main__":
    # "시작" means "start"; the literal had been corrupted by an encoding
    # mix-up into two stray combining characters.
    print('시작')
    main()