File size: 3,638 Bytes
9348624 f9bc137 b91b0a5 9348624 f9bc137 9348624 b91b0a5 f9bc137 9348624 f9bc137 4ff2e4d b91b0a5 f9bc137 4ff2e4d f9bc137 b91b0a5 4ff2e4d b91b0a5 4ff2e4d 9681056 4ff2e4d f9bc137 b91b0a5 4ff2e4d b91b0a5 4ff2e4d f9bc137 b91b0a5 f9bc137 b91b0a5 f9bc137 b91b0a5 f9bc137 b91b0a5 f9bc137 b91b0a5 f9bc137 4ff2e4d f9bc137 b91b0a5 f9bc137 b91b0a5 4ff2e4d b91b0a5 4ff2e4d b91b0a5 4ff2e4d f9bc137 b91b0a5 f9bc137 b91b0a5 4ff2e4d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
from __future__ import annotations
import os
import logging
import time
from dataclasses import dataclass
from typing import List, Sequence
import numpy as np
from openai import OpenAI
from langchain_core.embeddings import Embeddings
logger = logging.getLogger(__name__)
@dataclass
class EmbeddingConfig:
    """Configuration for the embedding model.

    Attributes:
        api_base_url: Base URL of the SiliconFlow API.
        model: Identifier of the embedding model to request.
        dimension: Number of dimensions of the embedding vectors.
        batch_size: Number of texts sent per batch; must be at least 1.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    api_base_url: str = "https://api.siliconflow.com/v1"  # SiliconFlow API
    model: str = "Qwen/Qwen3-Embedding-4B"  # Embedding model
    dimension: int = 2048  # Vector dimensionality
    batch_size: int = 16  # Texts per batch

    def __post_init__(self) -> None:
        """Reject batch sizes that cannot drive a batching loop."""
        # A step of 0 makes range() raise, and a negative step produces no
        # batches at all, so fail early with an explicit message instead.
        if self.batch_size < 1:
            raise ValueError(
                f"batch_size phải >= 1, nhận được {self.batch_size}"
            )
# Module-wide cached configuration; populated on first request.
_embed_config: EmbeddingConfig | None = None


def get_embedding_config() -> EmbeddingConfig:
    """Return the shared embedding configuration, creating it on first use.

    The instance is stored in the module-level ``_embed_config`` variable, so
    every later call returns the same object.
    """
    global _embed_config
    if _embed_config is not None:
        return _embed_config
    _embed_config = EmbeddingConfig()
    return _embed_config
class QwenEmbeddings(Embeddings):
    """LangChain ``Embeddings`` wrapper for a Qwen model served by SiliconFlow.

    Requests go through the ``OpenAI`` client pointed at
    ``config.api_base_url``. Texts are sent in batches of
    ``config.batch_size`` and a batch is retried with exponential backoff when
    the API reports a rate limit.

    NOTE(review): ``config.dimension`` is not sent with the request, so the
    vector size is whatever the endpoint returns by default -- confirm that
    this is intended.
    """

    def __init__(self, config: EmbeddingConfig | None = None):
        """Create the API client.

        Args:
            config: Embedding settings. When omitted (or falsy), the shared
                configuration from ``get_embedding_config()`` is used.

        Raises:
            ValueError: If the ``SILICONFLOW_API_KEY`` environment variable is
                unset or blank.
        """
        self.config = config or get_embedding_config()

        api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
        if not api_key:
            raise ValueError("Chưa đặt biến môi trường SILICONFLOW_API_KEY")

        self._client = OpenAI(
            api_key=api_key,
            base_url=self.config.api_base_url,
        )
        logger.info("Đã khởi tạo QwenEmbeddings: %s", self.config.model)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string (used for search)."""
        return self._embed_texts([text])[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed several documents (used when indexing)."""
        return self._embed_texts(texts)

    def _embed_texts(self, texts: Sequence[str]) -> List[List[float]]:
        """Embed ``texts`` batch by batch, preserving input order.

        Args:
            texts: Texts to embed. An empty sequence yields an empty list
                without contacting the API.

        Returns:
            One embedding per input text, in the order the API returned them
            for each batch.

        Raises:
            ValueError: If ``config.batch_size`` is smaller than 1.
            Exception: Whatever the client raises once retries are exhausted
                or the failure is not a rate limit.
        """
        if not texts:
            return []

        batch_size = self.config.batch_size
        # A zero step makes range() raise and a negative step would skip every
        # batch and silently return nothing, so reject both explicitly.
        if batch_size < 1:
            raise ValueError(f"batch_size phải >= 1, nhận được {batch_size}")

        all_embeddings: List[List[float]] = []
        for start in range(0, len(texts), batch_size):
            batch = list(texts[start:start + batch_size])
            all_embeddings.extend(self._embed_batch(batch))
        return all_embeddings

    def _embed_batch(self, batch: List[str]) -> List[List[float]]:
        """Embed one batch, retrying on rate-limit errors.

        Up to three attempts are made; the waits between them are 1s and 2s.
        Any other error, or a rate limit on the last attempt, is re-raised
        unchanged.
        """
        max_attempts = 3
        attempt = 0
        while True:
            try:
                response = self._client.embeddings.create(
                    model=self.config.model,
                    input=batch,
                )
            except Exception as exc:
                is_last_attempt = attempt >= max_attempts - 1
                if is_last_attempt or not self._is_rate_limit_error(exc):
                    raise
                wait_time = 2 ** attempt  # Exponential backoff: 1s, then 2s
                logger.warning("Bị rate limit, đợi %ss...", wait_time)
                time.sleep(wait_time)
                attempt += 1
            else:
                # Build the batch result in one step so that a failed attempt
                # can never leave partial vectors behind.
                return [item.embedding for item in response.data]

    @staticmethod
    def _is_rate_limit_error(exc: Exception) -> bool:
        """Return True when ``exc`` looks like a rate-limit rejection."""
        # HTTP-status errors from the client carry ``status_code``; 429 means
        # "Too Many Requests" even when the message never says "rate".
        if getattr(exc, "status_code", None) == 429:
            return True
        # Message heuristic kept so that errors retried before still are.
        return "rate" in str(exc).lower()

    def embed_texts_np(self, texts: Sequence[str]) -> np.ndarray:
        """Embed ``texts`` and return the result as a ``float32`` NumPy array."""
        return np.asarray(self._embed_texts(list(texts)), dtype=np.float32)
# Aliases kept for backward compatibility
SiliconFlowConfig = EmbeddingConfig
get_config = get_embedding_config
|