Spaces:
Sleeping
Sleeping
| import logging | |
| from sentence_transformers import SentenceTransformer | |
| import chromadb | |
| from chromadb.config import Settings | |
| import uuid | |
| import os | |
| from config import EMBEDDING_MODEL, CHROMA_PERSIST_DIRECTORY, COLLECTION_NAME | |
| logger = logging.getLogger(__name__) | |
| _embedding_model_instance = None | |
| def get_embedding_model(): | |
| """Kiểm tra và khởi tạo embedding đảm bảo chỉ khởi tạo một lần""" | |
| global _embedding_model_instance | |
| if _embedding_model_instance is None: | |
| logger.info("Khởi tạo EmbeddingModel instance lần đầu") | |
| _embedding_model_instance = EmbeddingModel() | |
| else: | |
| logger.debug("Sử dụng EmbeddingModel instance đã có") | |
| return _embedding_model_instance | |
| class EmbeddingModel: | |
| def __init__(self): | |
| """Khởi tạo embedding model và ChromaDB client""" | |
| logger.info(f"Đang khởi tạo embedding model: {EMBEDDING_MODEL}") | |
| try: | |
| # Khởi tạo sentence transformer với trust_remote_code=True | |
| self.model = SentenceTransformer(EMBEDDING_MODEL, trust_remote_code=True) | |
| logger.info("Đã tải sentence transformer model") | |
| except Exception as e: | |
| logger.error(f"Lỗi khởi tạo model: {e}") | |
| # Thử với cache_folder explicit | |
| cache_dir = os.getenv('SENTENCE_TRANSFORMERS_HOME', '/app/.cache/sentence-transformers') | |
| self.model = SentenceTransformer(EMBEDDING_MODEL, cache_folder=cache_dir, trust_remote_code=True) | |
| logger.info("Đã tải sentence transformer model với cache folder explicit") | |
| # SỬA: Khai báo biến persist_directory local để tránh lỗi scope | |
| persist_directory = CHROMA_PERSIST_DIRECTORY | |
| # Đảm bảo thư mục ChromaDB tồn tại và có quyền ghi | |
| try: | |
| os.makedirs(persist_directory, exist_ok=True) | |
| # Test ghi file để kiểm tra permission | |
| test_file = os.path.join(persist_directory, 'test_permission.tmp') | |
| with open(test_file, 'w') as f: | |
| f.write('test') | |
| os.remove(test_file) | |
| logger.info(f"Thư mục ChromaDB đã sẵn sàng: {persist_directory}") | |
| except Exception as e: | |
| logger.error(f"Lỗi tạo/kiểm tra thư mục ChromaDB: {e}") | |
| # Fallback to /tmp directory | |
| import tempfile | |
| persist_directory = os.path.join(tempfile.gettempdir(), 'chroma_db') | |
| os.makedirs(persist_directory, exist_ok=True) | |
| logger.warning(f"Sử dụng thư mục tạm thời: {persist_directory}") | |
| # Khởi tạo ChromaDB client với persistent storage | |
| try: | |
| self.chroma_client = chromadb.PersistentClient( | |
| path=persist_directory, | |
| settings=Settings( | |
| anonymized_telemetry=False, | |
| allow_reset=True | |
| ) | |
| ) | |
| logger.info(f"Đã kết nối ChromaDB tại: {persist_directory}") | |
| except Exception as e: | |
| logger.error(f"Lỗi kết nối ChromaDB: {e}") | |
| # Fallback to in-memory client | |
| logger.warning("Fallback to in-memory ChromaDB client") | |
| self.chroma_client = chromadb.Client() | |
| # Lấy hoặc tạo collection với cosine similarity | |
| try: | |
| self.collection = self.chroma_client.get_collection(name=COLLECTION_NAME) | |
| logger.info(f"Đã kết nối collection '{COLLECTION_NAME}' với {self.collection.count()} items") | |
| except Exception: | |
| logger.info(f"Collection '{COLLECTION_NAME}' không tồn tại, tạo mới với cosine similarity...") | |
| self.collection = self.chroma_client.create_collection( | |
| name=COLLECTION_NAME, | |
| metadata={ | |
| "hnsw:space": "cosine", # Cosine distance | |
| "hnsw:M": 16, # Optimize for accuracy | |
| "hnsw:construction_ef": 100 | |
| } | |
| ) | |
| logger.info(f"Đã tạo collection mới với cosine similarity: {COLLECTION_NAME}") | |
| def _initialize_collection(self): | |
| """Khởi tạo collection với cosine similarity""" | |
| try: | |
| # Kiểm tra xem collection đã tồn tại chưa | |
| existing_collections = [col.name for col in self.chroma_client.list_collections()] | |
| if COLLECTION_NAME in existing_collections: | |
| self.collection = self.chroma_client.get_collection(name=COLLECTION_NAME) | |
| # Kiểm tra distance function hiện tại | |
| current_metadata = self.collection.metadata or {} | |
| current_space = current_metadata.get("hnsw:space", "l2") | |
| if current_space != "cosine": | |
| logger.warning(f"Collection hiện tại đang dùng {current_space}, cần migration sang cosine") | |
| if self.collection.count() > 0: | |
| self._migrate_to_cosine() | |
| else: | |
| # Collection trống, xóa và tạo lại | |
| self.chroma_client.delete_collection(name=COLLECTION_NAME) | |
| self._create_cosine_collection() | |
| else: | |
| logger.info(f"Đã kết nối collection '{COLLECTION_NAME}' với cosine similarity, {self.collection.count()} items") | |
| else: | |
| # Collection chưa tồn tại, tạo mới với cosine | |
| self._create_cosine_collection() | |
| except Exception as e: | |
| logger.error(f"Lỗi khởi tạo collection: {e}") | |
| # Fallback: tạo collection mới | |
| self._create_cosine_collection() | |
| def _create_cosine_collection(self): | |
| """Tạo collection mới với cosine similarity""" | |
| try: | |
| self.collection = self.chroma_client.create_collection( | |
| name=COLLECTION_NAME, | |
| metadata={"hnsw:space": "cosine"} | |
| ) | |
| logger.info(f"Đã tạo collection mới với cosine similarity: {COLLECTION_NAME}") | |
| except Exception as e: | |
| logger.error(f"Lỗi tạo collection với cosine: {e}") | |
| # Fallback về collection mặc định | |
| self.collection = self.chroma_client.get_or_create_collection(name=COLLECTION_NAME) | |
| logger.warning("Đã fallback về collection mặc định (có thể dùng L2)") | |
| def _migrate_to_cosine(self): | |
| """Migration collection từ L2 sang cosine""" | |
| try: | |
| logger.info("Bắt đầu migration collection sang cosine similarity...") | |
| # Backup toàn bộ data | |
| all_data = self.collection.get( | |
| include=['documents', 'metadatas', 'embeddings'], | |
| limit=self.collection.count() | |
| ) | |
| if not all_data['documents']: | |
| logger.info("Collection trống, chỉ cần tạo lại") | |
| self.chroma_client.delete_collection(name=COLLECTION_NAME) | |
| self._create_cosine_collection() | |
| return | |
| # Xóa collection cũ và tạo mới với cosine | |
| self.chroma_client.delete_collection(name=COLLECTION_NAME) | |
| self._create_cosine_collection() | |
| # Restore data theo batch | |
| documents = all_data['documents'] | |
| metadatas = all_data['metadatas'] | |
| embeddings = all_data['embeddings'] | |
| ids = all_data['ids'] | |
| batch_size = 100 | |
| total_items = len(documents) | |
| for i in range(0, total_items, batch_size): | |
| batch_docs = documents[i:i + batch_size] | |
| batch_metas = metadatas[i:i + batch_size] if metadatas else None | |
| batch_embeds = embeddings[i:i + batch_size] if embeddings else None | |
| batch_ids = ids[i:i + batch_size] | |
| if batch_embeds: | |
| # Có embeddings sẵn, dùng luôn | |
| self.collection.add( | |
| documents=batch_docs, | |
| metadatas=batch_metas, | |
| embeddings=batch_embeds, | |
| ids=batch_ids | |
| ) | |
| else: | |
| # Tính lại embeddings | |
| new_embeddings = self.encode(batch_docs, is_query=False) | |
| self.collection.add( | |
| documents=batch_docs, | |
| metadatas=batch_metas, | |
| embeddings=new_embeddings, | |
| ids=batch_ids | |
| ) | |
| logger.info(f"Migration progress: {min(i + batch_size, total_items)}/{total_items}") | |
| logger.info(f"Migration hoàn thành! Đã chuyển {total_items} items sang cosine similarity") | |
| except Exception as e: | |
| logger.error(f"Lỗi migration: {e}") | |
| # Tạo collection mới nếu migration thất bại | |
| self._create_cosine_collection() | |
| def test_embedding_quality(self): | |
| try: | |
| # Test cases | |
| test_cases = [ | |
| ("query: Tháp dinh dưỡng cho trẻ", "passage: Tháp dinh dưỡng cho trẻ từ 6-11 tuổi"), | |
| ("query: dinh dưỡng", "passage: dinh dưỡng cho học sinh"), | |
| ("query: xin chào", "passage: Tháp dinh dưỡng cho trẻ") | |
| ] | |
| for query_text, doc_text in test_cases: | |
| # Encode | |
| query_emb = self.model.encode([query_text], normalize_embeddings=True)[0] | |
| doc_emb = self.model.encode([doc_text], normalize_embeddings=True)[0] | |
| # Calculate cosine similarity manually | |
| import numpy as np | |
| similarity = np.dot(query_emb, doc_emb) | |
| logger.info(f"Query: {query_text}") | |
| logger.info(f"Doc: {doc_text}") | |
| logger.info(f"Similarity: {similarity:.3f}") | |
| logger.info(f"Query norm: {np.linalg.norm(query_emb):.3f}") | |
| logger.info(f"Doc norm: {np.linalg.norm(doc_emb):.3f}") | |
| logger.info("-" * 50) | |
| except Exception as e: | |
| logger.error(f"Test embedding error: {e}") | |
| def _add_prefix_to_text(self, text, is_query=True): | |
| # Clean text trước | |
| text = text.strip() | |
| # Kiểm tra xem text đã có prefix chưa | |
| if text.startswith(('query:', 'passage:')): | |
| return text | |
| # Thêm prefix phù hợp | |
| if is_query: | |
| return f"query: {text}" | |
| else: | |
| return f"passage: {text}" | |
| def encode(self, texts, is_query=True): | |
| """ | |
| Encode văn bản thành embeddings với proper normalization | |
| """ | |
| try: | |
| if isinstance(texts, str): | |
| texts = [texts] | |
| # Thêm prefix cho texts (QUAN TRỌNG cho multilingual-e5-base) | |
| processed_texts = [self._add_prefix_to_text(text, is_query) for text in texts] | |
| logger.debug(f"Đang encode {len(processed_texts)} văn bản với prefix") | |
| logger.debug(f"Sample processed text: {processed_texts[0][:100]}...") | |
| # Encode với normalize_embeddings=True (QUAN TRỌNG!) | |
| embeddings = self.model.encode( | |
| processed_texts, | |
| show_progress_bar=False, | |
| normalize_embeddings=True # ✅ THÊM DÒNG NÀY | |
| ) | |
| # Double-check normalization | |
| import numpy as np | |
| for i, emb in enumerate(embeddings[:2]): # Check first 2 embeddings | |
| norm = np.linalg.norm(emb) | |
| logger.debug(f"Embedding {i} norm: {norm}") | |
| if abs(norm - 1.0) > 0.01: | |
| logger.warning(f"Embedding {i} not properly normalized: norm = {norm}") | |
| return embeddings.tolist() | |
| except Exception as e: | |
| logger.error(f"Lỗi encode văn bản: {e}") | |
| raise | |
| def search(self, query, top_k=5, age_filter=None): | |
| """Tìm kiếm văn bản tương tự trong ChromaDB""" | |
| try: | |
| query_embedding = self.encode(query, is_query=True)[0] | |
| where_clause = None | |
| if age_filter: | |
| where_clause = { | |
| "$and": [ | |
| {"age_min": {"$lte": age_filter}}, | |
| {"age_max": {"$gte": age_filter}} | |
| ] | |
| } | |
| print(f"🔍 AGE FILTER: Tìm kiếm cho tuổi {age_filter}") | |
| print(f"🔍 WHERE CLAUSE: {where_clause}") | |
| else: | |
| print(f"⚠️ KHÔNG CÓ AGE FILTER - Tìm tất cả chunks") | |
| search_results = self.collection.query( | |
| query_embeddings=[query_embedding], | |
| n_results=top_k, | |
| where=where_clause, | |
| include=['documents', 'metadatas', 'distances'] | |
| ) | |
| print(f"\n{'='*60}") | |
| print(f"📊 CHROMADB SEARCH RESULTS") | |
| print(f"{'='*60}") | |
| print(f"Query: {query}") | |
| print(f"Age filter: {age_filter}") | |
| print(f"Found {len(search_results['documents'][0]) if search_results['documents'] else 0} chunks") | |
| print(f"{'='*60}") | |
| if not search_results or not search_results['documents']: | |
| logger.warning("Không tìm thấy kết quả nào") | |
| return [] | |
| results = [] | |
| documents = search_results['documents'][0] | |
| metadatas = search_results['metadatas'][0] | |
| distances = search_results['distances'][0] | |
| for i, (doc, metadata, distance) in enumerate(zip(documents, metadatas, distances)): | |
| chunk_id = metadata.get('chunk_id', f'chunk_{i}') | |
| title = metadata.get('title', 'No title') | |
| age_range = metadata.get('age_range', 'Unknown') | |
| age_min = metadata.get('age_min', 'N/A') | |
| age_max = metadata.get('age_max', 'N/A') | |
| content_type = metadata.get('content_type', 'text') | |
| chapter = metadata.get('chapter', 'Unknown') | |
| similarity = round(1 - distance, 3) | |
| results.append({ | |
| 'document': doc, | |
| 'metadata': metadata or {}, | |
| 'distance': distance, | |
| 'similarity': similarity, | |
| 'rank': i + 1 | |
| }) | |
| print(f"\n{'='*60}") | |
| logger.info(f"Tim thay {len(results)} ket qua cho query") | |
| return results | |
| except Exception as e: | |
| logger.error(f"Loi tim kiem: {e}") | |
| return [] | |
| def add_documents(self, documents, metadatas=None, ids=None): | |
| """Thêm documents vào ChromaDB""" | |
| try: | |
| if not documents: | |
| logger.warning("Không có documents để thêm") | |
| return False | |
| if not ids: | |
| ids = [str(uuid.uuid4()) for _ in documents] | |
| if not metadatas: | |
| metadatas = [{} for _ in documents] | |
| logger.info(f"Đang thêm {len(documents)} documents vào ChromaDB") | |
| embeddings = self.encode(documents, is_query=False) | |
| self.collection.add( | |
| embeddings=embeddings, | |
| documents=documents, | |
| metadatas=metadatas, | |
| ids=ids | |
| ) | |
| logger.info(f"Đã thêm thành công {len(documents)} documents") | |
| return True | |
| except Exception as e: | |
| logger.error(f"Lỗi thêm documents: {e}") | |
| return False | |
| def index_chunks(self, chunks): | |
| """Index các chunks dữ liệu vào ChromaDB""" | |
| try: | |
| if not chunks: | |
| logger.warning("Không có chunks để index") | |
| return False | |
| documents = [] | |
| metadatas = [] | |
| ids = [] | |
| for chunk in chunks: | |
| if not chunk.get('content'): | |
| logger.warning(f"Chunk thiếu content: {chunk}") | |
| continue | |
| documents.append(chunk['content']) | |
| metadata = chunk.get('metadata', {}) | |
| metadatas.append(metadata) | |
| chunk_id = chunk.get('id') or str(uuid.uuid4()) | |
| ids.append(chunk_id) | |
| if not documents: | |
| logger.warning("Không có documents hợp lệ để index") | |
| return False | |
| batch_size = 100 | |
| total_batches = (len(documents) + batch_size - 1) // batch_size | |
| for i in range(0, len(documents), batch_size): | |
| batch_docs = documents[i:i + batch_size] | |
| batch_metas = metadatas[i:i + batch_size] | |
| batch_ids = ids[i:i + batch_size] | |
| batch_num = (i // batch_size) + 1 | |
| logger.info(f"Đang xử lý batch {batch_num}/{total_batches} ({len(batch_docs)} items)") | |
| success = self.add_documents(batch_docs, batch_metas, batch_ids) | |
| if not success: | |
| logger.error(f"Lỗi xử lý batch {batch_num}") | |
| return False | |
| logger.info(f"Đã index thành công {len(documents)} chunks") | |
| return True | |
| except Exception as e: | |
| logger.error(f"Lỗi index chunks: {e}") | |
| return False | |
| def count(self): | |
| """Đếm số lượng documents trong collection""" | |
| try: | |
| return self.collection.count() | |
| except Exception as e: | |
| logger.error(f"Lỗi đếm documents: {e}") | |
| return 0 | |
| def delete_collection(self): | |
| """Xóa collection hiện tại""" | |
| try: | |
| logger.warning(f"Đang xóa collection: {COLLECTION_NAME}") | |
| self.chroma_client.delete_collection(name=COLLECTION_NAME) | |
| # Tạo lại collection với cosine similarity | |
| self._create_cosine_collection() | |
| logger.info("Đã tạo lại collection mới với cosine similarity") | |
| return True | |
| except Exception as e: | |
| logger.error(f"Lỗi xóa collection: {e}") | |
| return False | |
| def get_collection_info(self): | |
| """Lấy thông tin về collection và distance function""" | |
| try: | |
| metadata = self.collection.metadata or {} | |
| distance_func = metadata.get("hnsw:space", "l2") | |
| return { | |
| 'collection_name': COLLECTION_NAME, | |
| 'distance_function': distance_func, | |
| 'total_documents': self.count(), | |
| 'metadata': metadata | |
| } | |
| except Exception as e: | |
| logger.error(f"Lỗi lấy collection info: {e}") | |
| return {'error': str(e)} | |
| def verify_cosine_similarity(self): | |
| """Kiểm tra và xác nhận đang sử dụng cosine similarity""" | |
| try: | |
| info = self.get_collection_info() | |
| distance_func = info.get('distance_function', 'unknown') | |
| logger.info(f"Collection đang sử dụng distance function: {distance_func}") | |
| if distance_func == "cosine": | |
| logger.info("Xác nhận: Đang sử dụng cosine similarity") | |
| return True | |
| else: | |
| logger.warning(f"Cảnh báo: Đang sử dụng {distance_func}, không phải cosine") | |
| return False | |
| except Exception as e: | |
| logger.error(f"Lỗi verify cosine: {e}") | |
| return False | |
| def get_stats(self): | |
| """Lấy thống kê về collection""" | |
| try: | |
| total_count = self.count() | |
| collection_info = self.get_collection_info() | |
| sample_results = self.collection.get(limit=min(100, total_count)) | |
| content_types = {} | |
| chapters = {} | |
| age_groups = {} | |
| if sample_results and sample_results.get('metadatas'): | |
| for metadata in sample_results['metadatas']: | |
| if not metadata: | |
| continue | |
| content_type = metadata.get('content_type', 'unknown') | |
| content_types[content_type] = content_types.get(content_type, 0) + 1 | |
| chapter = metadata.get('chapter', 'unknown') | |
| chapters[chapter] = chapters.get(chapter, 0) + 1 | |
| age_group = metadata.get('age_group', 'unknown') | |
| age_groups[age_group] = age_groups.get(age_group, 0) + 1 | |
| return { | |
| 'total_documents': total_count, | |
| 'content_types': content_types, | |
| 'chapters': chapters, | |
| 'age_groups': age_groups, | |
| 'collection_name': COLLECTION_NAME, | |
| 'embedding_model': EMBEDDING_MODEL, | |
| 'distance_function': collection_info.get('distance_function', 'unknown'), | |
| 'using_cosine_similarity': collection_info.get('distance_function') == 'cosine' | |
| } | |
| except Exception as e: | |
| logger.error(f"Lỗi lấy stats: {e}") | |
| return { | |
| 'total_documents': 0, | |
| 'error': str(e) | |
| } |