Spaces:
Sleeping
Sleeping
| import chromadb | |
| import logging | |
| import open_clip | |
| import torch | |
| from PIL import Image | |
| import numpy as np | |
| from transformers import pipeline | |
| import requests | |
| import io | |
| from concurrent.futures import ThreadPoolExecutor | |
| from tqdm import tqdm | |
| import os | |
# Logging setup: every record goes to both a log file and the console.
logging.basicConfig(
    handlers=[logging.FileHandler('db_creation.log'), logging.StreamHandler()],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def load_models():
    """Load the image encoder and the clothes segmentation pipeline.

    Returns:
        tuple: ``(model, preprocess_val, segmenter, device)`` where ``model``
        is the open_clip model moved to ``device`` and switched to eval mode,
        ``preprocess_val`` is the validation transform returned by open_clip,
        ``segmenter`` is the transformers pipeline, and ``device`` is CUDA
        when available, otherwise CPU.

    Raises:
        Exception: any error raised while loading is logged and re-raised.
    """
    try:
        logger.info("Loading models...")
        # Image encoder (CLIP-style checkpoint loaded through open_clip)
        model, _, preprocess_val = open_clip.create_model_and_transforms(
            'hf-hub:Marqo/marqo-fashionSigLIP'
        )
        # Segmentation model
        segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")
        model.to(device)
        # open_clip hands the model back in training mode; switch to eval so
        # layers such as dropout / stochastic depth are inactive and the
        # embeddings are deterministic at inference time.
        model.eval()
        return model, preprocess_val, segmenter, device
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise
def process_segmentation(image, segmenter):
    """Segment *image* and return the mask of its largest non-background region.

    Args:
        image: Image passed unchanged to ``segmenter``.
        segmenter: Callable returning a list of dict-like segments, each with
            a ``'mask'`` entry and, optionally, a ``'label'`` entry.

    Returns:
        numpy.ndarray | None: The mask (as an array) with the largest pixel
        sum among the non-background segments, or ``None`` when there are no
        segments, only background segments, or segmentation raised.
    """
    try:
        segments = segmenter(image)
        # Ignore the background class: it usually covers the most pixels, so
        # picking it as "largest" would blank out the garment downstream.
        # Segments without a label are kept, matching the previous behavior.
        # NOTE(review): other non-garment classes (e.g. skin/hair) are still
        # eligible - confirm whether they should be filtered too.
        masks = [
            np.array(segment['mask'])
            for segment in (segments or [])
            if str(segment.get('label', '')).strip().lower() != 'background'
        ]
        if not masks:
            return None
        # Pick the largest remaining segment (first one wins on ties)
        return max(masks, key=np.sum)
    except Exception as e:
        logger.error(f"Segmentation error: {e}")
        return None
def extract_features(image, mask, model, preprocess_val, device):
    """Encode *image* into a normalized feature vector, optionally masked.

    Args:
        image: Input image; converted with ``np.array`` when a mask is given.
        mask: ``None`` to encode the image unchanged, otherwise an array
            where zero marks background and any non-zero value marks
            foreground (assumed to match the image's height and width).
        model: Object exposing ``encode_image(tensor)``.
        preprocess_val: Callable turning the image into a tensor.
        device: Torch device the input tensor is moved to.

    Returns:
        numpy.ndarray | None: Flattened, L2-normalized features, or ``None``
        when any step raised (the error is logged).
    """
    try:
        if mask is not None:
            pixels = np.array(image)
            # Whiten background pixels directly instead of multiplying by the
            # mask: masks stored as uint8 0/255 make ``pixels * mask`` wrap
            # around modulo 256 and corrupt the foreground colors.
            background = np.array(mask) == 0
            if background.ndim == 3:
                background = background[:, :, 0]
            pixels[background] = 255  # paint the background white
            image = Image.fromarray(pixels.astype(np.uint8))
        image_tensor = preprocess_val(image).unsqueeze(0).to(device)
        with torch.no_grad():
            features = model.encode_image(image_tensor)
            # Scale to unit length along the feature dimension
            features /= features.norm(dim=-1, keepdim=True)
        return features.cpu().numpy().flatten()
    except Exception as e:
        logger.error(f"Feature extraction error: {e}")
        return None
def download_and_process_image(url, metadata_id, model, preprocess_val, segmenter, device):
    """Fetch one image over HTTP, segment it, and return its feature vector.

    Returns the features produced by ``extract_features``, or ``None`` when
    the download does not answer with HTTP 200, no mask is found, feature
    extraction fails, or any exception is raised along the way. Each failure
    is logged together with ``metadata_id``.
    """
    try:
        response = requests.get(url, timeout=10)
        if response.status_code != 200:
            logger.error(f"Failed to download image {metadata_id}: HTTP {response.status_code}")
            return None

        # Decode the payload and force three-channel RGB
        raw_bytes = io.BytesIO(response.content)
        image = Image.open(raw_bytes).convert('RGB')

        # Segmentation step
        mask = process_segmentation(image, segmenter)
        if mask is None:
            logger.warning(f"No valid mask found for image {metadata_id}")
            return None

        # Feature extraction step; a None result is logged and passed through
        features = extract_features(image, mask, model, preprocess_val, device)
        if features is None:
            logger.warning(f"Failed to extract features for image {metadata_id}")
        return features
    except Exception as e:
        logger.error(f"Error processing image {metadata_id}: {e}")
        return None
def create_segmented_db(source_path, target_path, batch_size=100):
    """Create a new segmented-feature database from an existing one.

    Reads every item's metadata from the ``clothes`` collection under
    ``source_path``, downloads each item's ``image_url``, computes
    segmentation-based features, and stores them in a freshly created
    ``clothes_segmented`` collection under ``target_path``.

    Args:
        source_path: Directory of the source Chroma database.
        target_path: Directory of the target Chroma database (created if
            missing; an existing ``clothes_segmented`` collection is replaced).
        batch_size: Number of source items processed and written per batch;
            must be at least 1.

    Returns:
        bool: ``True`` when the run finished (individual items may still have
        failed; counts are logged), ``False`` when an unrecoverable error
        occurred.
    """
    try:
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        logger.info("Loading models...")
        model, preprocess_val, segmenter, device = load_models()

        # Connect to the source DB
        source_client = chromadb.PersistentClient(path=source_path)
        source_collection = source_client.get_collection(name="clothes")

        # Create the target DB, replacing any collection left by a previous run
        os.makedirs(target_path, exist_ok=True)
        target_client = chromadb.PersistentClient(path=target_path)
        try:
            target_client.delete_collection("clothes_segmented")
        except Exception as e:
            # Expected on a first run, when there is nothing to delete yet.
            logger.debug(f"No existing target collection removed: {e}")
        target_collection = target_client.create_collection(
            name="clothes_segmented",
            metadata={"description": "Clothes collection with segmentation-based features"}
        )

        # Fetch all items; the collection's own ids come back alongside the metadata
        all_items = source_collection.get(include=['metadatas'])
        source_ids = all_items['ids']
        source_metadatas = all_items['metadatas']
        total_items = len(source_metadatas)
        logger.info(f"Found {total_items} items in source database")

        successful_updates = 0
        failed_updates = 0

        max_workers = min(10, os.cpu_count() or 4)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Walk the data batch by batch (slicing clamps the final batch)
            for batch_start in tqdm(range(0, total_items, batch_size), desc="Processing batches"):
                batch_end = batch_start + batch_size
                batch_pairs = zip(
                    source_ids[batch_start:batch_end],
                    source_metadatas[batch_start:batch_end],
                )

                # Submit one download/feature job per item that has an image URL
                futures = []
                for source_id, metadata in batch_pairs:
                    if not metadata or 'image_url' not in metadata:
                        logger.warning(f"Skipping item {source_id}: no image_url in metadata")
                        failed_updates += 1
                        continue
                    # Prefer the id stored in the metadata; otherwise reuse the
                    # source collection id. (hash() of a str changes between
                    # interpreter runs, so it cannot serve as a stable id.)
                    item_id = str(metadata.get('id', source_id))
                    future = executor.submit(
                        download_and_process_image,
                        metadata['image_url'],
                        item_id,
                        model, preprocess_val, segmenter, device
                    )
                    futures.append((item_id, metadata, future))

                # Collect the batch results
                batch_embeddings = []
                batch_metadatas = []
                batch_ids = []
                for item_id, metadata, future in futures:
                    try:
                        features = future.result()
                        embedding = None if features is None else features.tolist()
                    except Exception as e:
                        logger.error(f"Error processing batch item: {e}")
                        failed_updates += 1
                        continue
                    if embedding is None:
                        failed_updates += 1
                        continue
                    batch_embeddings.append(embedding)
                    batch_metadatas.append(metadata)
                    batch_ids.append(item_id)

                if not batch_embeddings:
                    continue

                # Store the batch; items count as successful only once written
                try:
                    target_collection.add(
                        embeddings=batch_embeddings,
                        metadatas=batch_metadatas,
                        ids=batch_ids
                    )
                except Exception as e:
                    logger.error(f"Error adding batch to collection: {e}")
                    failed_updates += len(batch_embeddings)
                else:
                    logger.info(f"Added batch of {len(batch_embeddings)} items")
                    successful_updates += len(batch_embeddings)

        # Final summary; an empty source must not trigger a division by zero
        completion_rate = (successful_updates / total_items) * 100 if total_items else 0.0
        logger.info("Database creation completed.")
        logger.info(f"Successfully processed: {successful_updates}")
        logger.info(f"Failed: {failed_updates}")
        logger.info(f"Total completion rate: {completion_rate:.2f}%")
        return True
    except Exception as e:
        # logger.exception keeps the traceback for post-mortem debugging
        logger.exception(f"Database creation error: {e}")
        return False
if __name__ == "__main__":
    # Configuration
    SOURCE_DB_PATH = "./clothesDB_11GmarketMusinsa"  # path of the original DB
    TARGET_DB_PATH = "./clothesDB_11GmarketMusinsa_segmented"  # path of the new DB
    BATCH_SIZE = 50  # number of items handled per batch

    # Run the DB creation and report the outcome
    success = create_segmented_db(SOURCE_DB_PATH, TARGET_DB_PATH, BATCH_SIZE)
    if not success:
        logger.error("Failed to create segmented database.")
    else:
        logger.info("Successfully created segmented database!")