from __future__ import annotations import os import logging import time from dataclasses import dataclass from typing import List, Sequence import numpy as np from openai import OpenAI from langchain_core.embeddings import Embeddings logger = logging.getLogger(__name__) @dataclass class EmbeddingConfig: """Cấu hình cho embedding model.""" api_base_url: str = "https://api.siliconflow.com/v1" # SiliconFlow API model: str = "Qwen/Qwen3-Embedding-4B" # Model embedding dimension: int = 2048 # Số chiều vector batch_size: int = 16 # Số text mỗi batch _embed_config: EmbeddingConfig | None = None def get_embedding_config() -> EmbeddingConfig: """Lấy cấu hình embedding (singleton pattern).""" global _embed_config if _embed_config is None: _embed_config = EmbeddingConfig() return _embed_config class QwenEmbeddings(Embeddings): """Wrapper embedding model Qwen qua SiliconFlow API""" def __init__(self, config: EmbeddingConfig | None = None): """Khởi tạo embedding client.""" self.config = config or get_embedding_config() api_key = os.getenv("SILICONFLOW_API_KEY", "").strip() if not api_key: raise ValueError("Chưa đặt biến môi trường SILICONFLOW_API_KEY") self._client = OpenAI( api_key=api_key, base_url=self.config.api_base_url, ) logger.info(f"Đã khởi tạo QwenEmbeddings: {self.config.model}") def embed_query(self, text: str) -> List[float]: """Embed một câu query (dùng cho search).""" return self._embed_texts([text])[0] def embed_documents(self, texts: List[str]) -> List[List[float]]: """Embed nhiều documents (dùng khi index).""" return self._embed_texts(texts) def _embed_texts(self, texts: Sequence[str]) -> List[List[float]]: """Embed danh sách texts theo batch với retry logic.""" if not texts: return [] all_embeddings: List[List[float]] = [] batch_size = self.config.batch_size max_retries = 3 # Xử lý theo batch for i in range(0, len(texts), batch_size): batch = list(texts[i:i + batch_size]) # Retry logic cho rate limit for attempt in range(max_retries): try: response = self._client.embeddings.create( model=self.config.model, input=batch, ) for item in response.data: all_embeddings.append(item.embedding) break except Exception as e: # Nếu bị rate limit -> đợi rồi thử lại if "rate" in str(e).lower() and attempt < max_retries - 1: wait_time = 2 ** attempt # Exponential backoff: 1s, 2s, 4s logger.warning(f"Bị rate limit, đợi {wait_time}s...") time.sleep(wait_time) else: raise return all_embeddings def embed_texts_np(self, texts: Sequence[str]) -> np.ndarray: """Embed texts và trả về numpy array (tiện cho tính toán).""" return np.asarray(self._embed_texts(list(texts)), dtype=np.float32) # Alias để tương thích ngược SiliconFlowConfig = EmbeddingConfig get_config = get_embedding_config