# itda-segment / app.py
# leedoming's picture
# Update app.py
# ee78aad verified
import streamlit as st
import open_clip
import torch
from PIL import Image
import numpy as np
from transformers import pipeline
import chromadb
import logging
import io
import requests
from concurrent.futures import ThreadPoolExecutor
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Seed the session-state keys used by the UI. Keys that already exist (for
# example from a previous rerun of the script) keep their current value.
_SESSION_DEFAULTS = (
    ('image', None),
    ('detected_items', None),
    ('selected_item_index', None),
    ('upload_state', 'initial'),
    ('search_clicked', False),
)
for _state_key, _state_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Load models
@st.cache_resource
def load_models():
    """Load the image-embedding model, its preprocessing transform and the segmenter.

    The result is cached by ``st.cache_resource`` so the models are only
    built once per process rather than on every Streamlit rerun.

    Returns:
        tuple: ``(model, preprocess_val, segmenter, device)`` where ``model``
        is the open_clip model moved to ``device`` and put in eval mode,
        ``preprocess_val`` is the validation/inference transform returned by
        open_clip, ``segmenter`` is the transformers pipeline and ``device``
        is the torch device the embedding model lives on.

    Raises:
        Exception: any loading error is logged and re-raised.
    """
    try:
        # CLIP-style embedding model; the middle return value (the second
        # transform returned by open_clip) is not used by this app.
        model, _, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
        # Clothes segmentation model
        segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        # The model is only used for inference, so disable training-time
        # behaviour (dropout, stochastic depth, batch-norm updates).
        model.eval()
        return model, preprocess_val, segmenter, device
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise
# λͺ¨λΈ λ‘œλ“œ
clip_model, preprocess_val, segmenter, device = load_models()
# ChromaDB μ„€μ •
client = chromadb.PersistentClient(path="./clothesDB_11GmarketMusinsa")
collection = client.get_collection(name="clothes")
def extract_color_histogram(image, mask=None):
    """Return a 24-value HSV colour histogram of ``image``.

    Args:
        image: RGB image (PIL image or array convertible with ``np.array``)
            with shape (H, W, 3).
        mask: optional array of shape (H, W) or (H, W, 1). Pixels whose mask
            value is greater than zero are included; the actual non-zero
            value (1, 255, ...) does not matter.

    Returns:
        numpy.ndarray: 8 hue bins, 8 saturation bins and 8 value bins, each
        group normalised to sum to 1. A zero vector of length 24 is returned
        when no pixel is selected or when extraction fails.

    Note:
        Earlier revisions passed the (N, 3) pixel list straight to PIL, which
        read it as a grayscale image, and therefore did not produce real HSV
        histograms. Embeddings stored with that behaviour should be
        regenerated.
    """
    try:
        img_array = np.array(image)
        if mask is not None:
            mask_2d = np.asarray(mask)
            # Callers may hand over a mask that already carries a trailing
            # channel axis; collapse it instead of expanding it again.
            if mask_2d.ndim == 3:
                mask_2d = mask_2d[:, :, 0]
            # Select the foreground pixels directly. Multiplying by the mask
            # would corrupt the colours for masks stored as 0/255.
            valid_pixels = img_array[mask_2d > 0]
        else:
            valid_pixels = img_array.reshape(-1, 3)

        if len(valid_pixels) == 0:
            return np.zeros(24)  # 8 bins * 3 channels = 24 features

        # Shape the pixel list as an N x 1 RGB image. A plain (N, 3) array
        # would be interpreted by PIL as an N x 3 single-channel image.
        rgb_column = valid_pixels.reshape(-1, 1, 3).astype(np.uint8)
        hsv_pixels = np.array(Image.fromarray(rgb_column).convert('HSV')).reshape(-1, 3)

        histograms = []
        for channel in range(3):
            counts = np.histogram(hsv_pixels[:, channel], bins=8, range=(0, 256))[0]
            # Small epsilon avoids division by zero
            histograms.append(counts / (counts.sum() + 1e-8))
        return np.concatenate(histograms)
    except Exception as e:
        logger.error(f"Color histogram extraction error: {e}")
        return np.zeros(24)
def process_segmentation(image):
    """Run the segmentation pipeline and normalise its output.

    Args:
        image: image accepted by the module-level ``segmenter`` pipeline.

    Returns:
        list[dict]: one dict per segment that has a mask, with keys
        ``'label'`` (``'Unknown'`` if missing), ``'score'`` (whatever the
        pipeline reported, ``1.0`` if the key is absent) and ``'mask'``, a 2-D
        float array containing 1.0 for pixels of the segment and 0.0
        elsewhere. An empty list is returned when nothing is found or when
        segmentation fails.
    """
    try:
        # Work directly on the pipeline output
        output = segmenter(image)
        if not output:
            logger.warning("No segments found in image")
            return []

        processed_items = []
        for segment in output:
            label = segment.get('label', 'Unknown')
            mask = segment.get('mask')
            if mask is None:
                # Segments without a mask cannot be used downstream
                logger.warning(f"No mask found for segment with label {label}")
                continue

            mask = np.asarray(mask)
            # Use the first channel if the mask is not 2-D
            if mask.ndim > 2:
                mask = mask[:, :, 0]

            processed_items.append({
                'label': label,
                # Falls back to 1.0 only when the pipeline omits the key
                'score': segment.get('score', 1.0),
                # The image-segmentation pipeline delivers masks as 0/255
                # images; callers multiply images by the mask, so it has to
                # be binarised to 0.0/1.0 rather than merely cast to float.
                'mask': (mask > 0).astype(float),
            })

        logger.info(f"Successfully processed {len(processed_items)} segments")
        return processed_items
    except Exception as e:
        # logger.exception records the traceback as well
        logger.exception(f"Segmentation error: {str(e)}")
        return []
def extract_features(image, mask=None):
    """Build the 768-value search embedding for an image region.

    The embedding is the first 744 values of the image-encoder output
    (weight 0.7) followed by the 24-value colour histogram (weight 0.3),
    L2-normalised as a whole.

    Args:
        image: RGB PIL image.
        mask: optional array of shape (H, W) or (H, W, 1); pixels with a
            value greater than zero belong to the item. Everything else is
            painted white before encoding.

    Returns:
        numpy.ndarray: vector of length 768 with unit L2 norm.

    Raises:
        Exception: any failure is logged and re-raised.
    """
    # Layout of the combined vector
    total_dim = 768
    color_dim = 24
    clip_dim = total_dim - color_dim  # 744
    clip_weight = 0.7
    color_weight = 0.3

    try:
        color_mask = None
        if mask is not None:
            img_array = np.array(image)
            foreground = np.asarray(mask)
            if foreground.ndim == 3:
                foreground = foreground[:, :, 0]
            foreground = foreground > 0
            # Keep the item pixels and set the background to white. Using
            # a boolean selection works for 0/1 and 0/255 masks alike.
            masked_img = np.where(foreground[:, :, None], img_array, 255)
            image = Image.fromarray(masked_img.astype(np.uint8))
            # The colour histogram expects a 2-D mask; passing the expanded
            # (H, W, 1) mask made it fail and fall back to zeros.
            color_mask = foreground.astype(float)

        # Image-encoder features
        image_tensor = preprocess_val(image).unsqueeze(0).to(device)
        with torch.no_grad():
            clip_features = clip_model.encode_image(image_tensor)
            clip_features /= clip_features.norm(dim=-1, keepdim=True)
        clip_features = clip_features.cpu().numpy().flatten()

        # Colour features
        color_features = extract_color_histogram(image, color_mask)

        # Trim the encoder features to make room for the colour features so
        # the total length still matches the collection dimensionality.
        clip_features = clip_features[:clip_dim]

        clip_features_normalized = clip_features / (np.linalg.norm(clip_features) + 1e-8)
        color_features_normalized = color_features / (np.linalg.norm(color_features) + 1e-8)

        combined_features = np.zeros(total_dim)
        combined_features[:clip_dim] = clip_features_normalized * clip_weight
        combined_features[clip_dim:] = color_features_normalized * color_weight

        # Final normalisation of the combined vector
        combined_features = combined_features / (np.linalg.norm(combined_features) + 1e-8)
        return combined_features
    except Exception as e:
        logger.error(f"Feature extraction error: {e}")
        raise
def download_and_process_image(image_url, metadata_id):
    """Download an image, segment it and return its search embedding.

    Args:
        image_url: URL of the image to fetch (10 second timeout).
        metadata_id: identifier used only in log messages.

    Returns:
        numpy.ndarray | None: features of the largest detected segment, or
        ``None`` when the download fails, no segment is found or any step
        raises.
    """
    try:
        response = requests.get(image_url, timeout=10)
        if response.status_code != 200:
            logger.error(f"Failed to download image {metadata_id}: HTTP {response.status_code}")
            return None

        image = Image.open(io.BytesIO(response.content)).convert('RGB')
        logger.info(f"Successfully downloaded image {metadata_id}")

        processed_items = process_segmentation(image)
        if processed_items:
            # The background usually covers the most pixels, so picking the
            # largest segment blindly would embed the backdrop instead of the
            # garment. Skip it whenever another segment is available.
            # NOTE(review): assumes the segmenter names that class
            # "Background" - confirm against the model's label list.
            candidates = [
                item for item in processed_items
                if str(item.get('label', '')).lower() != 'background'
            ] or processed_items
            # Use the mask of the largest remaining segment
            largest_mask = max(candidates, key=lambda item: np.sum(item['mask']))['mask']
            features = extract_features(image, largest_mask)
            logger.info(f"Successfully extracted features for image {metadata_id}")
            return features

        logger.warning(f"No valid mask found for image {metadata_id}")
        return None
    except Exception as e:
        # logger.exception records the traceback as well
        logger.exception(f"Error processing image {metadata_id}: {str(e)}")
        return None
def update_db_with_segmentation():
    """Rebuild the "clothes_segmented" collection with segmentation-based features.

    For every item of the source collection whose metadata has an
    ``'image_url'``, the image is downloaded and embedded (four worker
    threads) and the result is stored in a freshly created
    "clothes_segmented" collection. Progress is shown through Streamlit
    widgets.

    Returns:
        bool: ``True`` if at least one item was stored, otherwise ``False``
        (including when the update fails as a whole).
    """
    try:
        logger.info("Starting database update with segmentation and color features")

        # Create a fresh target collection
        try:
            client.delete_collection("clothes_segmented")
            logger.info("Deleted existing segmented collection")
        except Exception:
            # Deleting a collection that does not exist raises; nothing to do
            logger.info("No existing segmented collection to delete")

        new_collection = client.create_collection(
            name="clothes_segmented",
            metadata={"description": "Clothes collection with segmentation and color features"}
        )
        logger.info("Created new segmented collection")

        # Fetch only ids and metadata from the source collection
        try:
            all_items = collection.get(include=['metadatas'])
            source_metadatas = all_items.get('metadatas') or []
            source_ids = all_items.get('ids') or []
            logger.info(f"Found {len(source_metadatas)} items in database")
        except Exception as e:
            logger.error(f"Error getting items from collection: {str(e)}")
            source_metadatas, source_ids = [], []

        if len(source_ids) != len(source_metadatas):
            # Without a reliable pairing fall back to metadata-derived ids
            source_ids = [None] * len(source_metadatas)

        # Progress widgets
        progress_bar = st.progress(0)
        status_text = st.empty()
        successful_updates = 0
        failed_updates = 0

        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = []
            # Only items that carry an image URL can be processed
            for source_id, metadata in zip(source_ids, source_metadatas):
                if not metadata or 'image_url' not in metadata:
                    continue
                # Prefer an explicit id from the metadata, then the id the
                # item has in the source collection. The hash of the URL is a
                # last resort only: str hashes differ between processes.
                item_id = metadata.get('id')
                if item_id is None:
                    item_id = source_id
                if item_id is None:
                    item_id = hash(metadata['image_url'])
                item_id = str(item_id)
                future = executor.submit(
                    download_and_process_image,
                    metadata['image_url'],
                    item_id
                )
                futures.append((item_id, metadata, future))

            # Collect the results and store them in the new collection
            total = len(futures)
            for idx, (item_id, metadata, future) in enumerate(futures):
                try:
                    new_features = future.result()
                    if new_features is None:
                        failed_updates += 1
                    else:
                        try:
                            new_collection.add(
                                embeddings=[new_features.tolist()],
                                metadatas=[metadata],
                                ids=[item_id]
                            )
                            successful_updates += 1
                            logger.info(f"Successfully added item {item_id}")
                        except Exception as e:
                            logger.error(f"Error adding item to new collection: {str(e)}")
                            failed_updates += 1
                except Exception as e:
                    logger.error(f"Error processing item: {str(e)}")
                    failed_updates += 1

                # Update the progress display, also after a failure
                progress_bar.progress((idx + 1) / total)
                status_text.text(f"Processing: {idx + 1}/{total} items. Success: {successful_updates}, Failed: {failed_updates}")

        # Final summary
        status_text.text(f"Update completed. Successfully processed: {successful_updates}, Failed: {failed_updates}")
        logger.info(f"Database update completed. Successful: {successful_updates}, Failed: {failed_updates}")

        if successful_updates > 0:
            return True
        logger.error("No items were successfully processed")
        return False
    except Exception as e:
        # logger.exception records the traceback as well
        logger.exception(f"Database update error: {str(e)}")
        return False
def _distance_to_similarity(distance):
    """Map a ChromaDB distance to a similarity percentage in [0, 100].

    NOTE(review): assumes unit-length embeddings and a squared-L2 or cosine
    collection space, where distances lie in [0, 4] or [0, 2] - confirm the
    configuration of the collections. The mapping is monotonic either way,
    so the ranking is preserved; the result is clamped to stay a valid
    percentage.
    """
    similarity = 1.0 - float(distance) / 2.0
    return max(0.0, min(1.0, similarity)) * 100.0


def search_similar_items(features, top_k=10):
    """Query the vector store for the items closest to ``features``.

    The "clothes_segmented" collection is used when it exists, otherwise the
    module-level ``collection``.

    Args:
        features: 1-D numpy array used as the query embedding.
        top_k: number of neighbours to request.

    Returns:
        list[dict]: a copy of each hit's metadata with an added
        ``'similarity_score'`` (0-100, higher is more similar), sorted from
        most to least similar. An empty list is returned when nothing is
        found or the query fails.
    """
    try:
        # Prefer the collection built with segmentation features
        try:
            search_collection = client.get_collection("clothes_segmented")
            logger.info("Using segmented collection for search")
        except Exception:
            # Fall back to the original collection
            search_collection = collection
            logger.info("Using original collection for search")

        # 'distances' is the key ChromaDB uses for the per-hit distance;
        # asking for 'scores' is rejected and made every search fail.
        results = search_collection.query(
            query_embeddings=[features.tolist()],
            n_results=top_k,
            include=['metadatas', 'distances']
        )
        if not results or not results.get('metadatas') or not results.get('distances'):
            logger.warning("No results returned from ChromaDB")
            return []

        similar_items = []
        for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
            try:
                item_data = metadata.copy()
                item_data['similarity_score'] = _distance_to_similarity(distance)
                similar_items.append(item_data)
            except Exception as e:
                logger.error(f"Error processing search result: {str(e)}")
                continue

        similar_items.sort(key=lambda item: item['similarity_score'], reverse=True)
        return similar_items
    except Exception as e:
        logger.error(f"Search error: {str(e)}")
        return []
def _format_won(amount):
    """Format a price: numbers get thousands separators and the won suffix."""
    if isinstance(amount, (int, float)) and not isinstance(amount, bool):
        return f"{amount:,}원"
    return f"{amount}"


def show_similar_items(similar_items):
    """Render search hits in a two-column grid.

    Args:
        similar_items: list of dicts as returned by ``search_similar_items``;
            each must contain ``'similarity_score'`` and may contain
            ``'image_url'``, ``'brand'``, ``'name'``, ``'price'``,
            ``'discount'`` and ``'original_price'``.

    A warning is shown when the list is empty. An item that cannot be
    rendered is logged and replaced by an error message; the remaining items
    are still shown.
    """
    if not similar_items:
        st.warning("No similar items found.")
        return

    st.subheader("Similar Items:")

    # Show the results in two columns
    items_per_row = 2
    for row_start in range(0, len(similar_items), items_per_row):
        row_items = similar_items[row_start:row_start + items_per_row]
        cols = st.columns(items_per_row)
        for col, item in zip(cols, row_items):
            with col:
                try:
                    if 'image_url' in item:
                        st.image(item['image_url'], use_column_width=True)

                    # Similarity score shown as a percentage
                    similarity_percent = item['similarity_score']
                    st.markdown(f"**Similarity: {similarity_percent:.1f}%**")
                    st.write(f"Brand: {item.get('brand', 'Unknown')}")

                    # Metadata values are not guaranteed to be strings
                    name = str(item.get('name', 'Unknown'))
                    if len(name) > 50:  # shorten long names
                        name = name[:47] + "..."
                    st.write(f"Name: {name}")

                    # Price information
                    st.write(f"Price: {_format_won(item.get('price', 0))}")

                    # Discount information, if any
                    if 'discount' in item and item['discount']:
                        st.write(f"Discount: {item['discount']}%")
                        if 'original_price' in item:
                            # Same type guard as for the price, so a
                            # non-numeric value no longer breaks the card
                            st.write(f"Original: {_format_won(item['original_price'])}")

                    st.divider()  # separator between items
                except Exception as e:
                    logger.error(f"Error displaying item: {e}")
                    st.error("Error displaying this item")
def process_search(image, mask, num_results):
    """Embed the selected item and look up similar items.

    Returns the list produced by ``search_similar_items``, or ``None`` when
    either step raises (the error is logged).
    """
    try:
        with st.spinner("Extracting features..."):
            query_features = extract_features(image, mask)
        with st.spinner("Finding similar items..."):
            matches = search_similar_items(query_features, top_k=num_results)
    except Exception as e:
        logger.error(f"Search processing error: {e}")
        return None
    else:
        return matches
def handle_file_upload():
    """Widget callback: store the uploaded file as an RGB image in session state.

    Does nothing when the uploader holds no file.
    """
    if st.session_state.uploaded_file is not None:
        image = Image.open(st.session_state.uploaded_file).convert('RGB')
        st.session_state.image = image
        st.session_state.upload_state = 'image_uploaded'
        # No st.rerun() here: Streamlit reruns the script after a widget
        # callback anyway, and calling st.rerun() inside a callback is a
        # no-op that only produces a warning.
def handle_detection():
    """Widget callback: segment the stored image and keep the result in session state.

    Does nothing when no image has been uploaded.
    """
    if st.session_state.image is not None:
        detected_items = process_segmentation(st.session_state.image)
        st.session_state.detected_items = detected_items
        st.session_state.upload_state = 'items_detected'
        # No st.rerun() here: Streamlit reruns the script after a widget
        # callback anyway, and calling st.rerun() inside a callback is a
        # no-op that only produces a warning.
def handle_search():
    """Record in session state that a search has been requested."""
    st.session_state["search_clicked"] = True
def _reset_session():
    """Remove every session-state entry so the app starts over."""
    for key in list(st.session_state.keys()):
        del st.session_state[key]


def _render_detected_items(detected_items):
    """Show each detected segment (masked preview, label, confidence) in two columns."""
    cols = st.columns(2)
    for idx, item in enumerate(detected_items):
        with cols[idx % 2]:
            try:
                mask = item.get('mask')
                if mask is not None:
                    image_array = np.array(st.session_state.image)
                    # Select pixels with a boolean mask instead of
                    # multiplying, so 0/1 and 0/255 masks both work.
                    keep = np.expand_dims(np.asarray(mask) > 0, axis=2)
                    masked_img = np.where(keep, image_array, 0)
                    st.image(masked_img.astype(np.uint8), caption=f"Detected {item.get('label', 'Unknown')}")

                st.write(f"Item {idx + 1}: {item.get('label', 'Unknown')}")

                # Only show the score when it is present and numeric
                score = item.get('score')
                if score is not None and isinstance(score, (int, float)):
                    st.write(f"Confidence: {score*100:.1f}%")
                else:
                    st.write("Confidence: N/A")
            except Exception as e:
                logger.error(f"Error displaying item {idx}: {str(e)}")
                st.error(f"Error displaying item {idx}")


def _render_search_section(detected_items):
    """Item selector, search controls and search results."""
    valid_items = [idx for idx, item in enumerate(detected_items) if item.get('mask') is not None]
    if not valid_items:
        st.warning("No valid items detected for search.")
        return

    # Item selection
    selected_idx = st.selectbox(
        "Select item to search:",
        valid_items,
        format_func=lambda i: f"{detected_items[i].get('label', 'Unknown')}",
        key='item_selector'
    )

    # Search controls
    search_col1, search_col2 = st.columns([1, 2])
    with search_col1:
        search_clicked = st.button("Search Similar Items",
                                   key='search_button',
                                   type="primary")
    with search_col2:
        num_results = st.slider("Number of results:",
                                min_value=1,
                                max_value=20,
                                value=5,
                                key='num_results')

    # Nothing to show until a search has been requested at least once
    if not (search_clicked or st.session_state.get('search_clicked', False)):
        return
    st.session_state.search_clicked = True

    selected_item = detected_items[selected_idx]
    if selected_item.get('mask') is None:
        st.error("Selected item has no valid mask for search.")
        return

    # Cache the results in session state, but run the search again when the
    # button is pressed or the selected item / result count changed, so
    # stale results are never shown for a different query.
    search_params = (selected_idx, num_results)
    if (search_clicked
            or 'search_results' not in st.session_state
            or st.session_state.get('search_params') != search_params):
        st.session_state.search_results = process_search(
            st.session_state.image, selected_item['mask'], num_results)
        st.session_state.search_params = search_params

    # Show the stored results
    if st.session_state.search_results:
        show_similar_items(st.session_state.search_results)
    else:
        st.warning("No similar items found.")


def main():
    """Streamlit entry point: upload an image, detect items, search similar ones."""
    st.title("Fashion Search App")

    # Admin controls in sidebar
    st.sidebar.title("Admin Controls")
    if st.sidebar.checkbox("Show Admin Interface"):
        # Placeholder for an admin interface
        st.sidebar.warning("Admin interface is not implemented yet.")
        st.divider()

    # File uploader, only shown until an image has been accepted
    if st.session_state.upload_state == 'initial':
        st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'],
                         key='uploaded_file', on_change=handle_file_upload)

    if st.session_state.image is None:
        return

    # An image has been uploaded
    st.image(st.session_state.image, caption="Uploaded Image", use_column_width=True)

    detected_items = st.session_state.detected_items
    if detected_items is None:
        st.button("Detect Items", key='detect_button', on_click=handle_detection)
    elif not detected_items:
        # Detection ran but found nothing; tell the user instead of showing
        # an empty page
        st.warning("No items were detected in this image.")
    else:
        _render_detected_items(detected_items)
        _render_search_section(detected_items)

    # Reset button, available whenever an image is loaded
    if st.button("Start New Search", key='new_search'):
        _reset_session()
        st.rerun()


if __name__ == "__main__":
    main()