Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import open_clip | |
| import torch | |
| from PIL import Image | |
| import numpy as np | |
| from transformers import pipeline | |
| import chromadb | |
| import logging | |
# Logging setup: configure the root logger at INFO level and create this module's logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
# Initialize session state: each key receives its default only when it is not
# already present, so values set on earlier runs are kept.
for _state_key, _state_default in (
    ('image', None),
    ('detected_items', None),
    ('selected_item_index', None),
    ('upload_state', 'initial'),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
del _state_key, _state_default
# Load models (cached so that Streamlit reruns reuse them instead of reloading).
@st.cache_resource
def load_models():
    """Load the fashion image encoder and the clothing segmentation pipeline.

    Streamlit re-executes the whole script on every widget interaction, so this
    loader is wrapped in ``st.cache_resource``. Without the cache, both models
    would be re-created on every click.

    Returns:
        tuple: ``(model, preprocess_val, segmenter, device)``.
            - ``model``: the ``Marqo/marqo-fashionSigLIP`` open_clip model, moved
              to ``device`` and switched to eval mode.
            - ``preprocess_val``: the validation transform returned alongside it.
            - ``segmenter``: a transformers pipeline for
              ``mattmdjaga/segformer_b2_clothes``.
            - ``device``: CUDA when available, otherwise CPU.

    Raises:
        Exception: any loading error is logged and re-raised.
    """
    try:
        # CLIP model (the second return value, the training transform, is unused).
        model, _, preprocess_val = open_clip.create_model_and_transforms(
            'hf-hub:Marqo/marqo-fashionSigLIP'
        )
        # Segmentation model
        segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        # open_clip's README notes models are created in train mode; switch to
        # inference mode so dropout/normalization layers behave deterministically.
        model.eval()
        return model, preprocess_val, segmenter, device
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise
# Load the models; the helper functions below read these module-level names.
(clip_model, preprocess_val, segmenter, device) = load_models()

# ChromaDB setup: open the on-disk store and fetch the existing "clothes" collection.
_CHROMA_PATH = "./clothesDB_11GmarketMusinsa"
client = chromadb.PersistentClient(path=_CHROMA_PATH)
collection = client.get_collection(name="clothes")
def process_segmentation(image):
    """Run the clothing segmenter on *image* and return one entry per segment.

    Args:
        image: PIL image, passed straight to the segmentation pipeline.

    Returns:
        list[dict]: one dict per segment with these keys:
            - ``'score'``: fraction of image pixels covered by the mask, in
              [0, 1]. This is coverage, not a model confidence.
            - ``'label'``: class name reported by the pipeline.
            - ``'mask'``: ``uint8`` array, 1 for the segment's pixels and 0
              elsewhere.
        An empty list is returned (after logging) if segmentation fails.
    """
    try:
        segments = segmenter(image)
        valid_items = []
        for s in segments:
            # The HF segmentation pipeline returns masks as 0/255 grayscale images
            # (verify against the installed transformers version). Normalise to
            # 0/1 so the mean is a true fraction and multiplying a uint8 RGB array
            # by the mask cannot overflow.
            mask_array = (np.asarray(s['mask']) > 0).astype(np.uint8)
            valid_items.append({
                'score': float(mask_array.mean()),
                'label': s['label'],
                'mask': mask_array,
            })
        return valid_items
    except Exception as e:
        logger.error(f"Segmentation error: {e}")
        return []
def extract_features(image, mask=None):
    """Embed *image*, optionally restricted to *mask*, with the CLIP model.

    Args:
        image: PIL RGB image.
        mask: optional 2-D array with the same height/width as *image*. Any
            non-zero pixel counts as foreground, so both 0/1 and 0/255 masks
            work. Background pixels are painted white before embedding.

    Returns:
        numpy.ndarray: 1-D feature vector, L2-normalised.

    Raises:
        Exception: any error is logged and re-raised.
    """
    try:
        if mask is not None:
            img_array = np.asarray(image)
            foreground = np.asarray(mask) > 0
            # Select pixels instead of multiplying. uint8 * 255 wraps around
            # (x*255 mod 256) and would invert the foreground colours.
            masked_img = np.where(foreground[..., None], img_array, 255).astype(np.uint8)
            image = Image.fromarray(masked_img)
        image_tensor = preprocess_val(image).unsqueeze(0).to(device)
        with torch.no_grad():
            features = clip_model.encode_image(image_tensor)
            features /= features.norm(dim=-1, keepdim=True)
        return features.cpu().numpy().flatten()
    except Exception as e:
        logger.error(f"Feature extraction error: {e}")
        raise
def search_similar_items(features, top_k=10):
    """Query the ChromaDB collection for the *top_k* nearest neighbours of *features*.

    Each result is the stored metadata dict, mutated in place to carry a
    ``'similarity_score'`` of ``1 / (1 + distance)``. For a non-negative
    distance this lies in (0, 1]; higher means more similar.

    Returns an empty list (after logging) if the query or scoring fails.
    """
    try:
        response = collection.query(
            query_embeddings=[features.tolist()],
            n_results=top_k,
            include=['metadatas', 'distances'],  # distances are needed for the score
        )
        pairs = list(zip(response['metadatas'][0], response['distances'][0]))
        for meta, dist in pairs:
            # Map the distance onto a 0~1 similarity score.
            meta['similarity_score'] = 1 / (1 + dist)
        return [meta for meta, _ in pairs]
    except Exception as e:
        logger.error(f"Search error: {e}")
        return []
def _format_amount(value):
    """Return *value* with thousands separators if numeric, else as plain text."""
    # The ',' format spec only works for numbers. The 'Unknown' fallback, or a
    # price stored as a string, would raise ValueError with f"{value:,}".
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return f"{value:,}"
    return str(value)


def show_similar_items(similar_items):
    """Display similar items as (image, details) rows with similarity scores.

    Args:
        similar_items: metadata dicts as returned by ``search_similar_items``.
            Each must contain ``'similarity_score'``. ``'image_url'``,
            ``'brand'``, ``'name'``, ``'price'``, ``'discount'`` and
            ``'original_price'`` are optional.
    """
    st.subheader("Similar Items:")
    for item in similar_items:
        col1, col2 = st.columns([1, 2])
        with col1:
            image_url = item.get('image_url')
            if image_url:
                st.image(image_url)
        with col2:
            # Show the similarity score as a percentage.
            similarity_percent = item['similarity_score'] * 100
            st.write(f"Similarity: {similarity_percent:.1f}%")
            st.write(f"Brand: {item.get('brand', 'Unknown')}")
            st.write(f"Name: {item.get('name', 'Unknown')}")
            st.write(f"Price: {_format_amount(item.get('price', 'Unknown'))}μ")
            if 'discount' in item:
                st.write(f"Discount: {item['discount']}%")
            if 'original_price' in item:
                st.write(f"Original Price: {_format_amount(item['original_price'])}μ")
# Initialize session state, including the search flag. Only missing keys get a
# default, so values already stored in session state are left untouched.
_SESSION_DEFAULTS = {
    'image': None,
    'detected_items': None,
    'selected_item_index': None,
    'upload_state': 'initial',
    'search_clicked': False,
}
for _key in _SESSION_DEFAULTS:
    if _key not in st.session_state:
        st.session_state[_key] = _SESSION_DEFAULTS[_key]
del _key
def reset_state():
    """Delete every key currently stored in Streamlit session state."""
    # Snapshot the keys first; deleting while iterating the live mapping is unsafe.
    doomed = [name for name in st.session_state.keys()]
    for name in doomed:
        del st.session_state[name]
# Callback functions
def handle_file_upload():
    """``on_change`` callback for the uploader: load the uploaded file as an RGB image.

    On success, stores the image in ``st.session_state.image`` and sets
    ``upload_state`` to ``'image_uploaded'``. Unreadable files are logged and
    reported, and the state is left unchanged.

    No ``st.rerun()`` is called here. Streamlit reruns the script after a
    callback anyway, and ``st.rerun()`` inside a callback is a no-op.
    """
    uploaded = st.session_state.uploaded_file
    if uploaded is None:
        return
    try:
        # Use a context manager so the file handle is released; convert()
        # returns a new, fully loaded image.
        with Image.open(uploaded) as img:
            image = img.convert('RGB')
    except OSError as e:
        # PIL raises UnidentifiedImageError (an OSError subclass) for corrupt or
        # non-image data.
        logger.error(f"Image load error: {e}")
        st.error("Could not read the uploaded file as an image.")
        return
    st.session_state.image = image
    st.session_state.upload_state = 'image_uploaded'
def handle_detection():
    """``on_click`` callback for "Detect Items": segment the stored image.

    Saves the result list (possibly empty) in
    ``st.session_state.detected_items`` and sets ``upload_state`` to
    ``'items_detected'``. No ``st.rerun()`` is needed because Streamlit reruns
    after a callback, and calling it inside a callback is a no-op.
    """
    if st.session_state.image is None:
        return
    st.session_state.detected_items = process_segmentation(st.session_state.image)
    st.session_state.upload_state = 'items_detected'
def handle_search():
    """``on_click`` callback for "Search Similar Items": record that a search was requested."""
    st.session_state['search_clicked'] = True
def _render_detected_items(detected_items, image):
    """Show each detected item, masked out of *image*, in a two-column grid."""
    image_array = np.asarray(image)  # hoisted: the same source image for every item
    cols = st.columns(2)
    for idx, item in enumerate(detected_items):
        with cols[idx % 2]:
            # Select pixels instead of multiplying. With a 0/255 mask, uint8
            # multiplication wraps around and inverts the colours.
            foreground = np.asarray(item['mask']) > 0
            masked_img = np.where(foreground[..., None], image_array, 0).astype(np.uint8)
            st.image(masked_img, caption=f"Detected {item['label']}")
            st.write(f"Item {idx + 1}: {item['label']}")
            st.write(f"Confidence: {item['score']*100:.1f}%")


def _render_search(detected_items, image):
    """Let the user pick a detected item and, once requested, show similar catalogue items."""
    # Item selection
    selected_idx = st.selectbox(
        "Select item to search:",
        range(len(detected_items)),
        format_func=lambda i: f"{detected_items[i]['label']}",
        key='item_selector',
    )
    st.session_state.selected_item_index = selected_idx

    # Similar-item search controls
    col1, col2 = st.columns([1, 2])
    with col1:
        st.button("Search Similar Items",
                  key='search_button',
                  on_click=handle_search,
                  type="primary")  # highlighted button
    with col2:
        num_results = st.slider("Number of results:",
                                min_value=1,
                                max_value=20,
                                value=5,
                                key='num_results')

    # search_clicked stays True after the first click, so results refresh
    # whenever the selection or slider changes.
    if not st.session_state.search_clicked:
        return
    with st.spinner("Searching similar items..."):
        try:
            selected_mask = detected_items[selected_idx]['mask']
            features = extract_features(image, selected_mask)
            similar_items = search_similar_items(features, top_k=num_results)
            if similar_items:
                show_similar_items(similar_items)
            else:
                st.warning("No similar items found.")
        except Exception as e:
            st.error(f"Error during search: {str(e)}")


def main():
    """Render the app: upload an image, detect clothing items, search for similar items.

    The flow is driven by ``st.session_state``:
        - ``upload_state`` controls whether the uploader is shown.
        - ``image`` and ``detected_items`` hold what the callbacks produced.
        - ``search_clicked`` keeps search results visible across reruns.
    """
    st.title("ν¬μ΄λΈλ fashion demo!!!")

    # File upload (shown only while upload_state is 'initial').
    if st.session_state.upload_state == 'initial':
        st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'],
                         key='uploaded_file', on_change=handle_file_upload)

    # Once an image has been uploaded:
    if st.session_state.image is not None:
        image = st.session_state.image
        st.image(image, caption="Uploaded Image", use_column_width=True)

        detected_items = st.session_state.detected_items
        if detected_items is None:
            # The callback does the work; Streamlit reruns afterwards.
            st.button("Detect Items", key='detect_button', on_click=handle_detection)
        elif not detected_items:
            # Detection ran but found nothing (or failed and returned []).
            st.warning("No items detected.")
        else:
            # Show the detected items, then the search UI.
            _render_detected_items(detected_items, image)
            _render_search(detected_items, image)

        # New search button
        if st.button("Start New Search ", key='new_search'):
            reset_state()
            st.rerun()
# Run the app only when this file is executed as the main script.
if __name__ == '__main__':
    main()