Spaces:
Build error
Build error
| import streamlit as st | |
| import open_clip | |
| import torch | |
| from PIL import Image | |
| import numpy as np | |
| from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation | |
| import chromadb | |
| import logging | |
| import io | |
| import requests | |
| from concurrent.futures import ThreadPoolExecutor | |
| from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction | |
| from chromadb.utils.data_loaders import ImageLoader | |
| # Logging setup: configure the root logger at INFO level and create the module-level logger used by the functions in this file. | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
class CustomFashionEmbeddingFunction:
    """Embedding function blending fashion-SigLIP image features with color histograms.

    Instances are callable: given a list of images they return one
    unit-length, 768-dimensional vector per image, built as
    ``0.7 * clip_embedding + 0.3 * expanded_color_histogram``.
    """

    EMBEDDING_DIM = 768      # length of every returned vector
    CLIP_WEIGHT = 0.7        # weight of the image-model embedding
    COLOR_WEIGHT = 0.3       # weight of the expanded color histogram
    HISTOGRAM_BINS = 8       # bins per HSV channel (3 * 8 = 24 color features)
    REQUEST_TIMEOUT = 10     # seconds to wait when downloading an image URL

    def __init__(self):
        """Load the Marqo fashion-SigLIP model and move it to the GPU when available."""
        self.model, _, self.preprocess = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        # open_clip hands back models in training mode; switch to inference
        # mode so dropout / stochastic-depth layers cannot perturb embeddings.
        self.model.eval()

    def _load_image(self, img):
        """Return ``img`` as an RGB PIL image.

        Accepts an http(s) URL, a local file path, raw encoded bytes, a
        NumPy array or a PIL image. Any other type is returned unchanged.
        """
        if isinstance(img, str):
            if img.startswith(('http://', 'https://')):
                response = requests.get(img, timeout=self.REQUEST_TIMEOUT)
                response.raise_for_status()
                return Image.open(io.BytesIO(response.content)).convert('RGB')
            # Not a URL: treat the string as a local file path.
            return Image.open(img).convert('RGB')
        if isinstance(img, bytes):
            return Image.open(io.BytesIO(img)).convert('RGB')
        if isinstance(img, np.ndarray):
            return Image.fromarray(img.astype('uint8')).convert('RGB')
        if isinstance(img, Image.Image):
            return img.convert('RGB')
        return img

    @staticmethod
    def _combine_features(clip_emb, color_feat):
        """Fuse one image embedding and one color histogram into a unit vector.

        The image embedding is zero-padded or truncated to ``EMBEDDING_DIM``;
        the histogram is stretched to the same length by repeating each bin.
        Both parts are L2-normalized, mixed with the class weights, and the
        result is normalized again.
        """
        dim = CustomFashionEmbeddingFunction.EMBEDDING_DIM
        if clip_emb.shape[0] < dim:
            padding = np.zeros(dim - clip_emb.shape[0])
            clip_emb = np.concatenate([clip_emb, padding])
        else:
            clip_emb = clip_emb[:dim]

        # 24 histogram bins * 32 repeats = 768 values.
        repeats = max(1, dim // max(1, color_feat.shape[0]))
        color_expanded = np.repeat(color_feat, repeats)

        # The epsilon keeps all-zero vectors from dividing by zero.
        clip_emb = clip_emb / (np.linalg.norm(clip_emb) + 1e-8)
        color_expanded = color_expanded / (np.linalg.norm(color_expanded) + 1e-8)

        combined = (clip_emb * CustomFashionEmbeddingFunction.CLIP_WEIGHT
                    + color_expanded * CustomFashionEmbeddingFunction.COLOR_WEIGHT)
        return combined / (np.linalg.norm(combined) + 1e-8)

    def __call__(self, input):
        """Embed a batch of images.

        Args:
            input: iterable of images (URL/path strings, encoded bytes,
                NumPy arrays or PIL images). The parameter name is kept as
                ``input`` because callers pass it by that name.

        Returns:
            ``numpy.ndarray`` of shape ``(len(input), 768)``.

        Raises:
            Exception: any loading or model failure is logged and re-raised.
        """
        try:
            # Decode (and, for URLs, download) every image exactly once and
            # reuse it for both the model features and the color histogram.
            images = [self._load_image(img) for img in input]
            if not images:
                return np.empty((0, self.EMBEDDING_DIM))

            batch = torch.stack([self.preprocess(img) for img in images]).to(self.device)
            with torch.no_grad():
                clip_features = self.model.encode_image(batch)
            clip_features = clip_features.cpu().numpy()

            color_features_list = [self.extract_color_histogram(img) for img in images]

            return np.array([
                self._combine_features(clip_emb, color_feat)
                for clip_emb, color_feat in zip(clip_features, color_features_list)
            ])
        except Exception as e:
            logger.error(f"Error in embedding function: {e}")
            raise

    def extract_color_histogram(self, image):
        """Extract color histogram from the image.

        Returns a 24-element vector: an 8-bin histogram for each of the H, S
        and V channels, each normalized to sum to 1. On any failure the error
        is logged and a zero vector is returned instead.
        """
        try:
            rgb_image = self._load_image(image)
            # Going through RGB first lets grayscale / RGBA / palette images
            # be converted to HSV instead of failing.
            hsv_pixels = np.array(rgb_image.convert('HSV'))

            histograms = []
            for channel in range(3):
                hist = np.histogram(hsv_pixels[:, :, channel],
                                    bins=self.HISTOGRAM_BINS, range=(0, 256))[0]
                histograms.append(hist / (hist.sum() + 1e-8))
            return np.concatenate(histograms)
        except Exception as e:
            logger.error(f"Color histogram extraction error: {e}")
            return np.zeros(3 * self.HISTOGRAM_BINS)
# Initialize session state: seed each slot the app reads, but only when it is
# not already present, so values survive Streamlit reruns.
_SESSION_DEFAULTS = {
    'image': None,
    'detected_items': None,
    'selected_item_index': None,
    'upload_state': 'initial',
    'search_clicked': False,
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Load segmentation model
@st.cache_resource(show_spinner=False)
def load_segmentation_model():
    """Load the clothes-segmentation model and its image processor.

    The result is cached by Streamlit, so the weights are loaded once per
    server process instead of on every detection request.

    Returns:
        ``(model, image_processor)``; the model is moved to CUDA when a GPU
        is available.

    Raises:
        Exception: any loading failure is logged and re-raised.
    """
    model_name = "mattmdjaga/segformer_b2_clothes"
    try:
        image_processor = AutoImageProcessor.from_pretrained(model_name)
        model = AutoModelForSemanticSegmentation.from_pretrained(model_name)
        if torch.cuda.is_available():
            model = model.to('cuda')
        return model, image_processor
    except Exception as e:
        logger.error(f"Error loading segmentation model: {e}")
        raise
# ChromaDB setup
def setup_multimodal_collection():
    """Open the multimodal fashion collection, creating it if it does not exist.

    The same collection name is used for lookup and for creation. Previously
    the lookup used "fashion_multimodal" while creation used
    "clothes_multimodal", so a created collection was never found again and
    the next creation attempt failed.

    Returns:
        The Chroma collection bound to the custom embedding function and the
        image data loader.

    Raises:
        Exception: any client or collection failure is logged and re-raised.
    """
    collection_name = "fashion_multimodal"
    try:
        client = chromadb.PersistentClient(path="./fashion_multimodal_db_original")
        embedding_function = CustomFashionEmbeddingFunction()
        data_loader = ImageLoader()

        # Prefer the existing collection.
        try:
            collection = client.get_collection(
                name=collection_name,
                embedding_function=embedding_function,
                data_loader=data_loader
            )
            logger.info(f"Successfully connected to existing {collection_name} collection")
            return collection
        except Exception as e:
            logger.error(f"Error getting existing collection: {e}")
            # Create it only when the lookup failed.
            collection = client.create_collection(
                name=collection_name,
                embedding_function=embedding_function,
                data_loader=data_loader
            )
            logger.info(f"Created new {collection_name} collection")
            return collection
    except Exception as e:
        logger.error(f"Error setting up multimodal collection: {e}")
        raise
def process_segmentation(image):
    """Segmentation processing

    Runs the segmentation model on ``image`` and returns one dict per
    non-background label found, with keys ``'label'`` (``"Item_<id>"``),
    ``'score'`` (always 1.0) and ``'mask'`` (float array, 1.0 where the pixel
    belongs to the label). Returns an empty list if anything fails.

    NOTE(review): ``image`` is presumably a PIL image, since ``image.size`` is
    read as (width, height) — confirm against callers.
    """
    try:
        model, image_processor = load_segmentation_model()

        # Pre-process the image into model tensors.
        inputs = image_processor(image, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = {name: tensor.to('cuda') for name, tensor in inputs.items()}

        # Inference.
        with torch.no_grad():
            outputs = model(**inputs)

        # Bring the logits back to the input resolution.
        target_size = image.size[::-1]  # (height, width)
        upsampled_logits = torch.nn.functional.interpolate(
            outputs.logits.cpu(),
            size=target_size,
            mode="bilinear",
            align_corners=False,
        )

        # Per-pixel class ids.
        seg_masks = upsampled_logits.argmax(dim=1).numpy()

        # One entry per label present in the prediction; label 0 is background.
        processed_items = [
            {
                'label': f"Item_{label_idx}",
                'score': 1.0,
                'mask': (seg_masks[0] == label_idx).astype(float),
            }
            for label_idx in np.unique(seg_masks)
            if label_idx != 0
        ]

        logger.info(f"Successfully processed {len(processed_items)} segments")
        return processed_items
    except Exception as e:
        logger.error(f"Segmentation error: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        return []
def search_similar_items(image, mask=None, top_k=10):
    """Run a multimodal similarity search for ``image``.

    Args:
        image: query image (anything ``np.array`` turns into an H x W x 3 array).
        mask: optional H x W array; when given, the image is multiplied by it
            so only the selected item contributes to the query.
        top_k: maximum number of neighbours to request.

    Returns:
        List of metadata dicts, each extended with ``'similarity_score'``
        (0-100), sorted from most to least similar. Returns an empty list on
        any error or when nothing is found.
    """
    try:
        collection = setup_multimodal_collection()

        # Apply the mask (broadcast over the color channels).
        if mask is not None:
            mask_3d = np.stack([mask] * 3, axis=-1)
            query_array = (np.array(image) * mask_3d).astype(np.uint8)
        else:
            query_array = np.array(image)

        results = collection.query(
            query_images=[query_array],
            n_results=top_k,
            include=['metadatas', 'distances']
        )

        if not results or not results.get('metadatas') or not results.get('distances'):
            return []

        similar_items = []
        for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
            # Chroma's default "l2" space reports the SQUARED L2 distance.
            # For unit-length vectors d2 = 2 * (1 - cos_sim), so
            # cos_sim = 1 - d2 / 2; squaring the distance again would be wrong.
            # NOTE(review): assumes the collection uses the default l2 space
            # and unit-normalized embeddings — confirm the collection metadata.
            cosine_similarity = 1 - (distance / 2)
            # Guard against tiny numerical overshoot outside [-1, 1].
            cosine_similarity = max(-1.0, min(1.0, cosine_similarity))

            # Map cosine similarity from [-1, 1] to [0, 100].
            similarity_score = ((cosine_similarity + 1) / 2) * 100

            item_data = dict(metadata) if metadata else {}
            item_data['similarity_score'] = similarity_score
            similar_items.append(item_data)

        similar_items.sort(key=lambda x: x['similarity_score'], reverse=True)
        return similar_items
    except Exception as e:
        logger.error(f"Multimodal search error: {e}")
        return []
def update_db_with_multimodal():
    """Rebuild the multimodal collection from the legacy "clothes" collection.

    Reads every metadata record from the old database, downloads the image at
    ``metadata['image_url']`` and adds it to the multimodal collection in
    batches of 100, reporting progress through Streamlit widgets. Records
    without an ``image_url`` are skipped.

    Returns:
        ``True`` when the run completes, ``False`` if an unexpected error
        aborts it (the error is logged).
    """
    # Local import: only needed for the deterministic fallback ids below.
    import hashlib

    request_timeout = 10  # seconds per image download

    try:
        # Target collection.
        collection = setup_multimodal_collection()

        # Source data from the legacy collection.
        client = chromadb.PersistentClient(path="./clothesDB_11GmarketMusinsa")
        old_collection = client.get_collection("clothes")
        old_data = old_collection.get(include=['metadatas'])
        metadatas = old_data['metadatas'] or []

        total_items = len(metadatas)
        progress_bar = st.progress(0)
        status_text = st.empty()

        batch_size = 100
        successful_updates = 0
        failed_updates = 0

        for i in range(0, total_items, batch_size):
            batch = metadatas[i:i + batch_size]

            images = []
            valid_metadatas = []
            valid_ids = []

            for metadata in batch:
                if not metadata or 'image_url' not in metadata:
                    continue
                image_url = metadata['image_url']
                try:
                    response = requests.get(image_url, timeout=request_timeout)
                    response.raise_for_status()
                    img = Image.open(io.BytesIO(response.content)).convert('RGB')
                except Exception as e:
                    logger.error(f"Error processing image: {e}")
                    failed_updates += 1
                    continue

                item_id = metadata.get('id')
                if item_id is None:
                    # Built-in hash() is salted per process, so it would give
                    # a different id on every run; a digest of the URL is stable.
                    item_id = hashlib.sha1(str(image_url).encode('utf-8')).hexdigest()

                images.append(np.array(img))
                valid_metadatas.append(metadata)
                valid_ids.append(str(item_id))  # Chroma ids must be strings
                successful_updates += 1

            if images:
                collection.add(
                    ids=valid_ids,
                    images=images,
                    metadatas=valid_metadatas
                )

            # Update progress
            processed = i + len(batch)
            progress_bar.progress(processed / total_items)
            status_text.text(f"Processing: {processed}/{total_items} items. "
                             f"Success: {successful_updates}, Failed: {failed_updates}")

        status_text.text(f"Update completed. Successfully processed: {successful_updates}, "
                         f"Failed: {failed_updates}")
        return True
    except Exception as e:
        logger.error(f"Multimodal DB update error: {e}")
        return False
def _format_price(value):
    """Format a price: thousands separators plus the won suffix for numbers, raw text otherwise."""
    if isinstance(value, (int, float)):
        return f"{value:,}원"
    return f"{value}"


def show_similar_items(similar_items):
    """Display similar items in a structured format with similarity scores.

    Args:
        similar_items: list of dicts as produced by the similarity search;
            the keys read are ``image_url``, ``similarity_score``, ``brand``,
            ``name``, ``price``, ``discount`` and ``original_price``.

    Items are laid out two per row. A failure while rendering one item is
    logged and shown as an error in that item's column only.
    """
    if not similar_items:
        st.warning("No similar items found.")
        return

    st.subheader("Similar Items:")

    items_per_row = 2
    for row_start in range(0, len(similar_items), items_per_row):
        row_items = similar_items[row_start:row_start + items_per_row]
        cols = st.columns(items_per_row)
        for col, item in zip(cols, row_items):
            with col:
                try:
                    if 'image_url' in item:
                        st.image(item['image_url'], use_column_width=True)

                    st.markdown(f"**Similarity: {item['similarity_score']:.1f}%**")
                    st.write(f"Brand: {item.get('brand', 'Unknown')}")

                    # Coerce to text so a non-string name cannot break len().
                    name = str(item.get('name', 'Unknown'))
                    if len(name) > 50:
                        name = name[:47] + "..."
                    st.write(f"Name: {name}")

                    # NOTE: the currency suffix was mis-encoded in this copy
                    # of the file; it is restored to the Korean won sign.
                    st.write(f"Price: {_format_price(item.get('price', 0))}")

                    if 'discount' in item and item['discount']:
                        st.write(f"Discount: {item['discount']}%")
                        if 'original_price' in item:
                            # Same type check as for the price: a text value
                            # used to raise on the ":," format.
                            st.write(f"Original: {_format_price(item['original_price'])}")

                    st.divider()
                except Exception as e:
                    logger.error(f"Error displaying item: {e}")
                    st.error("Error displaying this item")
def process_search(image, mask, num_results):
    """Run the similar-item search behind a spinner.

    Returns whatever ``search_similar_items`` returns, or ``None`` when an
    exception escapes (the error is logged).
    """
    try:
        with st.spinner("Finding similar items..."):
            return search_similar_items(image, mask, num_results)
    except Exception as e:
        logger.error(f"Search processing error: {e}")
        return None
def handle_file_upload():
    """Uploader ``on_change`` callback: store the uploaded file as an RGB image.

    On success ``st.session_state.image`` is set and ``upload_state`` becomes
    ``'image_uploaded'``. Streamlit reruns the script automatically after a
    callback, so no explicit ``st.rerun()`` is issued (inside a callback it is
    a no-op that only produces a warning).
    """
    uploaded = st.session_state.get('uploaded_file')
    if uploaded is None:
        return
    try:
        image = Image.open(uploaded).convert('RGB')
    except (OSError, ValueError) as e:
        # Unreadable or non-image upload: keep the app in its current state.
        logger.error(f"Error opening uploaded image: {e}")
        return
    st.session_state.image = image
    st.session_state.upload_state = 'image_uploaded'
def handle_detection():
    """Button ``on_click`` callback: segment the stored image and save the items.

    Sets ``st.session_state.detected_items`` and moves ``upload_state`` to
    ``'items_detected'``. Streamlit reruns the script automatically after a
    callback, so no explicit ``st.rerun()`` is issued (inside a callback it is
    a no-op that only produces a warning).
    """
    if st.session_state.image is None:
        return
    st.session_state.detected_items = process_segmentation(st.session_state.image)
    st.session_state.upload_state = 'items_detected'
def handle_search():
    """Record in session state that a search has been requested."""
    st.session_state['search_clicked'] = True
def _reset_session():
    """Clear all session state and rerun so the app returns to the upload step."""
    for key in list(st.session_state.keys()):
        del st.session_state[key]
    st.rerun()


def main():
    """Render the app: admin sidebar, image upload, item detection and search.

    Flow, driven by ``st.session_state``:
      1. upload an image (uploader shown while ``upload_state == 'initial'``);
      2. detect items in it;
      3. pick a detected item and search for similar products.
    """
    st.title("Fashion Search App")

    # Admin controls in sidebar
    st.sidebar.title("Admin Controls")
    if st.sidebar.checkbox("Show Admin Interface"):
        if st.sidebar.button("Update Database (Multimodal)"):
            with st.spinner("Updating database with multimodal support..."):
                success = update_db_with_multimodal()
            if success:
                st.sidebar.success("Database updated successfully!")
            else:
                st.sidebar.error("Failed to update database")
        st.divider()

    # File upload
    if st.session_state.upload_state == 'initial':
        st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'],
                         key='uploaded_file', on_change=handle_file_upload)

    # Nothing else to render until an image has been uploaded.
    if st.session_state.image is None:
        return

    st.image(st.session_state.image, caption="Uploaded Image", use_column_width=True)

    detected_items = st.session_state.detected_items
    if detected_items is None:
        st.button("Detect Items", key='detect_button', on_click=handle_detection)
        return

    if len(detected_items) == 0:
        # Detection ran but found nothing (or failed). Previously this state
        # rendered no controls at all, leaving the user unable to start over.
        st.warning("No items were detected in this image.")
        if st.button("Start New Search", key='new_search'):
            _reset_session()
        return

    # Show each detected item next to its masked image.
    cols = st.columns(2)
    for idx, item in enumerate(detected_items):
        with cols[idx % 2]:
            try:
                if item.get('mask') is not None:
                    masked_img = np.array(st.session_state.image) * np.expand_dims(item['mask'], axis=2)
                    st.image(masked_img.astype(np.uint8), caption=f"Detected {item.get('label', 'Unknown')}")

                st.write(f"Item {idx + 1}: {item.get('label', 'Unknown')}")
                score = item.get('score')
                if score is not None and isinstance(score, (int, float)):
                    st.write(f"Confidence: {score*100:.1f}%")
                else:
                    st.write("Confidence: N/A")
            except Exception as e:
                logger.error(f"Error displaying item {idx}: {str(e)}")
                st.error(f"Error displaying item {idx}")

    valid_items = [i for i, item in enumerate(detected_items)
                   if item.get('mask') is not None]

    if not valid_items:
        st.warning("No valid items detected for search.")
        if st.button("Start New Search", key='new_search'):
            _reset_session()
        return

    selected_idx = st.selectbox(
        "Select item to search:",
        valid_items,
        format_func=lambda i: f"{detected_items[i].get('label', 'Unknown')}",
        key='item_selector'
    )

    search_col1, search_col2 = st.columns([1, 2])
    with search_col1:
        search_clicked = st.button("Search Similar Items",
                                   key='search_button',
                                   type="primary")
    with search_col2:
        num_results = st.slider("Number of results:",
                                min_value=1,
                                max_value=20,
                                value=5,
                                key='num_results')

    if search_clicked or st.session_state.get('search_clicked', False):
        st.session_state.search_clicked = True
        selected_item = detected_items[selected_idx]

        if selected_item.get('mask') is None:
            st.error("Selected item has no valid mask for search.")
            return

        # Search again on every explicit button press. Stored results are
        # reused only on reruns caused by other widgets; previously the first
        # result set was kept forever, even after selecting a different item.
        if search_clicked or 'search_results' not in st.session_state:
            st.session_state.search_results = process_search(st.session_state.image,
                                                             selected_item['mask'],
                                                             num_results)

        if st.session_state.search_results:
            show_similar_items(st.session_state.search_results)
        else:
            st.warning("No similar items found.")

    # New-search button
    if st.button("Start New Search", key='new_search'):
        _reset_session()
| # Script entry point: print a start marker, then run the Streamlit app. | |
| # NOTE(review): the printed literal looks mis-encoded in this copy (presumably Korean for "start") - confirm against the original file. | |
| if __name__ == "__main__": | |
| print('์์') | |
| main() | |