|
|
from __future__ import annotations |
|
|
import os |
|
|
import logging |
|
|
import time |
|
|
from dataclasses import dataclass |
|
|
from typing import List, Sequence |
|
|
import numpy as np |
|
|
from openai import OpenAI |
|
|
from langchain_core.embeddings import Embeddings |
|
|
|
|
|
# Module-level logger, named after this module so handlers and filters can
# target it by its import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
@dataclass
class EmbeddingConfig:
    """Settings for the embedding model.

    All fields have defaults, so ``EmbeddingConfig()`` yields a usable
    configuration without arguments.
    """

    # Base URL handed to the OpenAI-compatible client.
    api_base_url: str = "https://api.siliconflow.com/v1"

    # Model identifier sent with every embedding request.
    model: str = "Qwen/Qwen3-Embedding-4B"

    # Embedding vector dimensionality configured for this model.
    dimension: int = 2048

    # Maximum number of texts sent in a single embedding request.
    batch_size: int = 16
|
|
|
|
|
|
|
|
# Shared configuration instance; stays ``None`` until it is first requested.
_embed_config: EmbeddingConfig | None = None


def get_embedding_config() -> EmbeddingConfig:
    """Return the shared embedding configuration, creating it on first use.

    The first call builds an ``EmbeddingConfig`` with its default field
    values and stores it at module level; every later call returns that
    same instance.
    """
    global _embed_config

    if _embed_config is not None:
        return _embed_config

    _embed_config = EmbeddingConfig()
    return _embed_config
|
|
|
|
|
|
|
|
class QwenEmbeddings(Embeddings):
    """LangChain ``Embeddings`` implementation backed by a Qwen model on SiliconFlow.

    Requests are sent through the OpenAI-compatible client pointed at
    ``config.api_base_url``. Inputs are split into batches of
    ``config.batch_size``, and a batch that fails with a rate-limit error is
    retried with exponential backoff.
    """

    def __init__(self, config: EmbeddingConfig | None = None):
        """Create the embedding client.

        Args:
            config: Embedding settings. When omitted, the shared configuration
                returned by ``get_embedding_config()`` is used.

        Raises:
            ValueError: If the ``SILICONFLOW_API_KEY`` environment variable is
                unset or contains only whitespace.
        """
        self.config = config or get_embedding_config()

        api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
        if not api_key:
            raise ValueError("Chưa đặt biến môi trường SILICONFLOW_API_KEY")

        self._client = OpenAI(
            api_key=api_key,
            base_url=self.config.api_base_url,
        )
        logger.info(f"Đã khởi tạo QwenEmbeddings: {self.config.model}")

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string (used at search time).

        Args:
            text: The query to embed.

        Returns:
            The embedding vector for ``text``.
        """
        return self._embed_texts([text])[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed several documents (used at indexing time).

        Args:
            texts: The documents to embed.

        Returns:
            One embedding vector per input text, in input order.
        """
        return self._embed_texts(texts)

    def _embed_texts(self, texts: Sequence[str]) -> List[List[float]]:
        """Embed ``texts`` batch by batch.

        Args:
            texts: The strings to embed; must support ``len()`` and slicing.

        Returns:
            The embedding vectors of all batches concatenated in request
            order, or an empty list when ``texts`` is empty.

        Raises:
            Exception: Whatever the API client raised for a batch that could
                not be embedded (see ``_embed_batch``).
        """
        if not texts:
            return []

        # A zero batch size would make range() raise and a negative one would
        # silently produce no batches, so always send at least one text.
        batch_size = max(1, self.config.batch_size)

        vectors: List[List[float]] = []
        for start in range(0, len(texts), batch_size):
            batch = list(texts[start:start + batch_size])
            vectors.extend(self._embed_batch(batch))
        return vectors

    def _embed_batch(self, batch: List[str]) -> List[List[float]]:
        """Embed one batch, retrying rate-limit failures with backoff.

        Up to three attempts are made. After a rate-limited attempt the
        method sleeps 1s, then 2s, before trying again. Vectors are only
        returned once a request succeeds, so a retried batch can never
        contribute partial or duplicated results.

        Args:
            batch: The texts to send in a single request.

        Returns:
            The embedding of every item in the response, in response order.

        Raises:
            Exception: The client's error, re-raised unchanged, when it is not
                a rate-limit error or when all attempts are exhausted.
        """
        request: dict = {"model": self.config.model, "input": batch}

        # Forward the configured vector size so the API returns vectors of
        # ``config.dimension`` length; a non-positive value omits the
        # parameter and leaves the model's default size in place.
        dimension = self.config.dimension
        if dimension and dimension > 0:
            request["dimensions"] = dimension

        max_attempts = 3
        failures = 0
        while True:
            try:
                response = self._client.embeddings.create(**request)
            except Exception as exc:
                failures += 1
                # HTTP 429 is the status attached to rate-limit errors; the
                # message check also covers errors carrying no status code.
                rate_limited = (
                    getattr(exc, "status_code", None) == 429
                    or "rate" in str(exc).lower()
                )
                if failures >= max_attempts or not rate_limited:
                    raise
                wait_time = 2 ** (failures - 1)
                logger.warning(f"Bị rate limit, đợi {wait_time}s...")
                time.sleep(wait_time)
            else:
                return [item.embedding for item in response.data]

    def embed_texts_np(self, texts: Sequence[str]) -> np.ndarray:
        """Embed ``texts`` and return the result as a NumPy array.

        Args:
            texts: The strings to embed; copied into a list before use.

        Returns:
            A ``float32`` array built from the list of embedding vectors.
        """
        return np.asarray(self._embed_texts(list(texts)), dtype=np.float32)
|
|
|
|
|
|
|
|
|
|
|
# Alternate names for the configuration class and its accessor.
# NOTE(review): presumably kept for older call sites; confirm before removing.
SiliconFlowConfig = EmbeddingConfig
get_config = get_embedding_config
|
|
|