import chromadb
import logging
import open_clip
import torch
from PIL import Image
import numpy as np
from transformers import pipeline
import requests
import io
import json
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import os
from io import BytesIO
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
from chromadb.utils.data_loaders import ImageLoader


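# Log to both a file and the console so long crawls can be monitored live
# and audited afterwards.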
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('fashion_db_creation.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


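# Models used throughout the pipeline: a fashion-tuned SigLIP (CLIP-style) encoder
# for image embeddings and a SegFormer clothes-segmentation pipeline for isolating
# garments before embedding.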
def load_models():
    try:
        logger.info("Loading models...")

        model, _, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')

        segmenter = pipeline("image-segmentation", model="mattmdjaga/segformer_b2_clothes")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        model.eval()  # inference only; disable dropout/batch-norm updates

        from torchvision import transforms
        resize_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

        return model, preprocess_val, segmenter, device, resize_transform
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise


def process_segmentation(image, segmenter):
    """Run clothes segmentation and return a float mask in [0, 1] for the largest segment."""
    try:
        output = segmenter(image)

        if not output:
            logger.warning("No segments found in image")
            return None

        # Pick the largest segment by pixel count, assuming it is the main garment.
        segment_sizes = [np.sum(seg['mask']) for seg in output]

        if not segment_sizes:
            return None

        largest_idx = np.argmax(segment_sizes)
        mask = output[largest_idx]['mask']

        if not isinstance(mask, np.ndarray):
            mask = np.array(mask)

        if len(mask.shape) > 2:
            mask = mask[:, :, 0]

        mask = mask.astype(float)
        # Segmentation masks typically come back as 0/255 images; scale to [0, 1]
        # so the mask can be used directly as a blending weight in extract_features.
        if mask.max() > 1.0:
            mask /= 255.0

        logger.info(f"Successfully created mask with shape {mask.shape}")
        return mask

    except Exception as e:
        logger.error(f"Segmentation error: {str(e)}")
        return None


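# Image download with simple retry; returns None (rather than raising) so a bad
# URL only skips that one product.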
def load_image_from_url(url, max_retries=3):
    for attempt in range(max_retries):
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            img = Image.open(BytesIO(response.content)).convert('RGB')
            return img
        except Exception as e:
            logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")
            if attempt < max_retries - 1:
                time.sleep(1)
            else:
                logger.error(f"Failed to load image from {url} after {max_retries} attempts")
                return None


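# The embedding for each product is a weighted blend of three views of the same
# image: the original photo, the garment composited onto a white background, and a
# tight crop around the garment. The crop gets the highest weight so fine garment
# details dominate the representation.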
def extract_features(image, mask, model, preprocess_val, device):
    """Advanced feature extraction with mask-based attention"""
    try:
        img_array = np.array(image)
        mask_3channel = np.repeat(mask[:, :, np.newaxis], 3, axis=2)

        # View 1: the original, unmasked image.
        image_tensor_original = preprocess_val(image).unsqueeze(0).to(device)

        # View 2: the garment on a white background.
        masked_img_white = img_array * mask_3channel + (1 - mask_3channel) * 255
        image_masked_white = Image.fromarray(masked_img_white.astype(np.uint8))
        image_tensor_masked = preprocess_val(image_masked_white).unsqueeze(0).to(device)

        # View 3: a square-padded crop around the garment's bounding box.
        bbox = get_bbox_from_mask(mask)
        cropped_img = crop_and_resize(img_array * mask_3channel, bbox)
        image_cropped = Image.fromarray(cropped_img.astype(np.uint8))
        image_tensor_cropped = preprocess_val(image_cropped).unsqueeze(0).to(device)

        with torch.no_grad():
            features_original = model.encode_image(image_tensor_original)
            features_masked = model.encode_image(image_tensor_masked)
            features_cropped = model.encode_image(image_tensor_cropped)

            # Weighted fusion of the three views.
            combined_features = (
                0.2 * features_original +
                0.3 * features_masked +
                0.5 * features_cropped
            )

            # L2-normalise so dot products behave as cosine similarity.
            combined_features /= combined_features.norm(dim=-1, keepdim=True)

        return combined_features.cpu().numpy().flatten()

    except Exception as e:
        logger.error(f"Feature extraction error: {e}")
        return None


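# Helpers below turn a segmentation mask into a padded, square crop of the garment.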
def get_bbox_from_mask(mask):
    """Extract bounding-box coordinates (with padding) from a binary mask."""
    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]

    padding = 10
    rmin = max(rmin - padding, 0)
    rmax = min(rmax + padding, mask.shape[0])
    cmin = max(cmin - padding, 0)
    cmax = min(cmax + padding, mask.shape[1])
    return rmin, rmax, cmin, cmax


def crop_and_resize(image, bbox):
    """Crop the image to the bounding box and pad it to a square on a white background."""
    rmin, rmax, cmin, cmax = bbox
    cropped = image[rmin:rmax, cmin:cmax]

    size = max(cropped.shape[:2])
    square_img = np.full((size, size, 3), 255, dtype=np.uint8)
    start_h = (size - cropped.shape[0]) // 2
    start_w = (size - cropped.shape[1]) // 2
    square_img[start_h:start_h+cropped.shape[0],
               start_w:start_w+cropped.shape[1]] = cropped
    return square_img


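# One item = one product record from a crawl JSON file. Returns a dict ready for
# collection.add(), or None if any step fails.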
def process_item(item, model, preprocess_val, segmenter, device, resize_transform):
    """Process single item from JSON data"""
    try:
        # The crawlers use different key names for the image URL.
        if '이미지 링크' in item:
            image_url = item['이미지 링크']
        elif '이미지 URL' in item:
            image_url = item['이미지 URL']
        else:
            logger.warning("No image URL found in item")
            return None

        metadata = create_metadata(item)

        image = load_image_from_url(image_url)
        if image is None:
            logger.warning(f"Failed to load image from {image_url}")
            return None

        mask = process_segmentation(image, segmenter)
        if mask is None:
            logger.warning(f"Failed to create segmentation mask for {image_url}")
            return None

        try:
            features = extract_features(image, mask, model, preprocess_val, device)
            if features is None:
                raise ValueError("Feature extraction failed")
        except Exception as e:
            logger.error(f"Feature extraction failed for {image_url}: {str(e)}")
            return None

        return {
            'id': metadata['product_id'],
            'embedding': features.tolist(),
            'metadata': metadata,
            'image_uri': image_url
        }

    except Exception as e:
        logger.error(f"Error processing item: {str(e)}")
        return None


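# Optional debugging helper: saves the original image and its mask to ./debug_images
# for visual inspection (not called in the main pipeline).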
def save_debug_images(image, mask, url):
    try:
        debug_dir = "debug_images"
        os.makedirs(debug_dir, exist_ok=True)

        filename = url.split('/')[-1].split('?')[0]

        image.save(f"{debug_dir}/original_{filename}")

        mask_img = Image.fromarray((mask * 255).astype(np.uint8))
        mask_img.save(f"{debug_dir}/mask_{filename}")

    except Exception as e:
        logger.warning(f"Failed to save debug images: {str(e)}")


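# The crawl sources use different Korean field names; normalise them into a single
# metadata schema, including a stable string product_id for ChromaDB.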
def create_metadata(item):
    """Create standardized metadata from different JSON formats"""
    metadata = {}

    # Prefer the source's own product ID; otherwise derive a stable ID from the
    # name and image URL (uuid5 is deterministic across runs, unlike hash()).
    if '상품 ID' in item:
        metadata['product_id'] = str(item['상품 ID'])
    else:
        unique_string = f"{item.get('상품명', '')}{item.get('이미지 URL', '')}"
        metadata['product_id'] = str(uuid.uuid5(uuid.NAMESPACE_URL, unique_string))

    metadata['brand'] = item.get('브랜드명', 'unknown')
    metadata['name'] = item.get('제품명') or item.get('상품명', 'unknown')
    metadata['price'] = (item.get('정가') or item.get('가격') or
                         item.get('판매가', 'unknown'))
    metadata['discount'] = item.get('할인율', 'unknown')

    if '카테고리' in item:
        if isinstance(item['카테고리'], list):
            metadata['category'] = '/'.join(item['카테고리'])
        else:
            metadata['category'] = item['카테고리']
    else:
        # No category field: infer one by keyword-matching the product name.
        name = metadata['name'].lower()
        categories = ['원피스', '셔츠', '블라우스', '니트', '가디건',
                      '스커트', '팬츠', '셋업', '아우터', '자켓']
        found_categories = [cat for cat in categories if cat in name]
        metadata['category'] = '/'.join(found_categories) if found_categories else 'unknown'

    metadata['image_url'] = (item.get('이미지 링크') or
                             item.get('이미지 URL', 'unknown'))

    # Infer the source marketplace from the key names / image CDN domain.
    if '이미지 링크' in item:
        metadata['source'] = 'musinsa'
    elif 'cdn.011st.com' in metadata['image_url']:
        metadata['source'] = '11st'
    elif 'gmarket' in metadata['image_url']:
        metadata['source'] = 'gmarket'
    else:
        metadata['source'] = 'unknown'

    return metadata


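# Build (or rebuild) the persistent ChromaDB collection from a list of crawl JSON
# files. Item embeddings are computed here and passed explicitly; the OpenCLIP
# embedding function attached to the collection is what ChromaDB uses to embed
# queries later.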
def create_multimodal_fashion_db(json_files):
    try:
        logger.info("Starting multimodal fashion database creation")

        model, preprocess_val, segmenter, device, resize_transform = load_models()

        client = chromadb.PersistentClient(path="./fashion_multimodal_db")

        embedding_function = OpenCLIPEmbeddingFunction()
        data_loader = ImageLoader()

        # Start from a clean collection each run.
        try:
            client.delete_collection("fashion_multimodal")
            logger.info("Deleted existing collection")
        except Exception:
            logger.info("No existing collection to delete")

        collection = client.create_collection(
            name="fashion_multimodal",
            embedding_function=embedding_function,
            data_loader=data_loader,
            metadata={"description": "Fashion multimodal collection with advanced feature extraction"}
        )

        stats = {
            'total_processed': 0,
            'successful': 0,
            'failed': 0,
            'feature_extraction_failed': 0
        }

        for json_file in json_files:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            logger.info(f"Processing {len(data)} items from {json_file}")

            # Download, segment and embed items in parallel threads.
            with ThreadPoolExecutor(max_workers=4) as executor:
                futures = []
                for item in data:
                    future = executor.submit(
                        process_item,
                        item, model, preprocess_val, segmenter, device, resize_transform
                    )
                    futures.append(future)

                processed_items = []
                for future in tqdm(futures, desc=f"Processing {json_file}"):
                    stats['total_processed'] += 1
                    result = future.result()

                    if result is not None:
                        processed_items.append(result)
                        stats['successful'] += 1
                    else:
                        stats['failed'] += 1

            # Add this file's successfully processed items in one batch.
            if processed_items:
                try:
                    collection.add(
                        ids=[item['id'] for item in processed_items],
                        embeddings=[item['embedding'] for item in processed_items],
                        metadatas=[item['metadata'] for item in processed_items],
                        uris=[item['image_uri'] for item in processed_items]
                    )
                except Exception as e:
                    logger.error(f"Failed to add batch to collection: {str(e)}")
                    stats['failed'] += len(processed_items)
                    stats['successful'] -= len(processed_items)

        logger.info("Processing completed:")
        logger.info(f"Total processed: {stats['total_processed']}")
        logger.info(f"Successful: {stats['successful']}")
        logger.info(f"Failed: {stats['failed']}")

        return stats['successful'] > 0

    except Exception as e:
        logger.error(f"Database creation error: {str(e)}")
        return False


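# Entry point: the paths below are local crawl dumps from Musinsa, 11st and Gmarket.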
if __name__ == "__main__":
    json_files = [
        './musinsa_ranking_images_category_0920.json',
        './11st/11st_bagaccessory_20241017_172846.json',
        './11st/11st_best_abroad_bagaccessory_20241017_173300.json',
        './11st/11st_best_abroad_fashion_20241017_173144.json',
        './11st/11st_best_abroad_luxury_20241017_173343.json',
        './11st/11st_best_men_20241017_172534.json',
        './11st/11st_best_women_20241017_172127.json',
        './gmarket/gmarket_best_accessory_20241015_155921.json',
        './gmarket/gmarket_best_bag_20241015_155811.json',
        './gmarket/gmarket_best_brand_20241015_155530.json',
        './gmarket/gmarket_best_casual_20241015_155421.json',
        './gmarket/gmarket_best_men_20241015_155025.json',
        './gmarket/gmarket_best_shoe_20241015_155613.json',
        './gmarket/gmarket_best_women_20241015_154206.json'
    ]

    success = create_multimodal_fashion_db(json_files)

    if success:
        print("Successfully created multimodal fashion database!")
    else:
        print("Failed to create database. Check the logs for details.")