File size: 3,638 Bytes
9348624
f9bc137
 
b91b0a5
9348624
 
 
f9bc137
9348624
b91b0a5
f9bc137
9348624
 
f9bc137
4ff2e4d
b91b0a5
 
 
 
 
f9bc137
 
4ff2e4d
f9bc137
b91b0a5
4ff2e4d
b91b0a5
4ff2e4d
 
9681056
4ff2e4d
f9bc137
 
 
b91b0a5
 
4ff2e4d
b91b0a5
4ff2e4d
f9bc137
 
 
b91b0a5
f9bc137
 
 
 
 
b91b0a5
f9bc137
 
b91b0a5
f9bc137
 
 
b91b0a5
f9bc137
 
 
b91b0a5
f9bc137
 
 
 
4ff2e4d
 
f9bc137
b91b0a5
f9bc137
 
 
b91b0a5
4ff2e4d
 
 
 
 
 
 
 
 
 
b91b0a5
4ff2e4d
b91b0a5
 
4ff2e4d
 
 
f9bc137
 
 
 
b91b0a5
f9bc137
 
 
b91b0a5
4ff2e4d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from __future__ import annotations
import os
import logging
import time
from dataclasses import dataclass
from typing import List, Sequence
import numpy as np
from openai import OpenAI
from langchain_core.embeddings import Embeddings

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


@dataclass()
class EmbeddingConfig:
    """Settings for the embedding model and the API endpoint that serves it."""

    # Base URL of the SiliconFlow OpenAI-compatible API.
    api_base_url: str = "https://api.siliconflow.com/v1"
    # Model identifier sent with every embedding request.
    model: str = "Qwen/Qwen3-Embedding-4B"
    # Number of dimensions of each embedding vector.
    dimension: int = 2048
    # Maximum number of texts sent in one API request.
    batch_size: int = 16


# Shared config instance; stays None until get_embedding_config() first runs.
_embed_config: EmbeddingConfig | None = None


def get_embedding_config() -> EmbeddingConfig:
    """Return the shared EmbeddingConfig, building a default one on first use.

    The instance is cached in the module global ``_embed_config``, so every
    subsequent call returns that same object.
    """
    global _embed_config
    if _embed_config is not None:
        return _embed_config
    _embed_config = EmbeddingConfig()
    return _embed_config


class QwenEmbeddings(Embeddings):
    """LangChain ``Embeddings`` wrapper for a Qwen embedding model on SiliconFlow.

    Requests go through the OpenAI-compatible client pointed at
    ``config.api_base_url``. Texts are sent in batches of ``config.batch_size``;
    a batch rejected for rate limiting is retried with exponential backoff.
    """

    # Total attempts per batch (first call + retries) before a rate-limit
    # error is re-raised to the caller.
    _MAX_RETRIES = 3

    def __init__(self, config: EmbeddingConfig | None = None):
        """Create the API client.

        Args:
            config: Embedding settings. Defaults to the shared config returned
                by ``get_embedding_config()``.

        Raises:
            ValueError: if the ``SILICONFLOW_API_KEY`` environment variable is
                unset or blank.
        """
        self.config = config or get_embedding_config()

        api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
        if not api_key:
            raise ValueError("Chưa đặt biến môi trường SILICONFLOW_API_KEY")

        self._client = OpenAI(
            api_key=api_key,
            base_url=self.config.api_base_url,
        )
        logger.info("Đã khởi tạo QwenEmbeddings: %s", self.config.model)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single search query and return its vector."""
        return self._embed_texts([text])[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed documents for indexing; one vector per text, in input order."""
        return self._embed_texts(texts)

    @staticmethod
    def _is_rate_limit_error(exc: Exception) -> bool:
        """Return True if *exc* looks like a rate-limit (HTTP 429) rejection.

        The OpenAI client's HTTP errors carry a ``status_code`` attribute; when
        it is absent, fall back to the phrase "rate limit" in the message. A
        bare "rate" substring is not enough: it also matches words such as
        "generate", "separate" or "accurate".
        """
        if getattr(exc, "status_code", None) == 429:
            return True
        message = str(exc).lower().replace("_", " ").replace("-", " ")
        return "rate limit" in message or "ratelimit" in message

    def _embed_batch(self, batch: List[str]) -> List[List[float]]:
        """Embed one batch, retrying rate-limited calls with backoff (1s, 2s, ...).

        Raises:
            RuntimeError: if the API returns a different number of vectors
                than texts sent (vectors would otherwise be misaligned).
            Exception: any non-rate-limit API error, or the last rate-limit
                error once ``_MAX_RETRIES`` attempts are exhausted.
        """
        request_kwargs = {"model": self.config.model, "input": batch}
        if self.config.dimension:
            # Ask the API for the configured vector size; without this the
            # declared dimension was never sent and the API returned its own
            # default size. NOTE(review): vectors indexed before this change
            # may have a different size — re-index if so.
            request_kwargs["dimensions"] = self.config.dimension

        for attempt in range(self._MAX_RETRIES):
            try:
                response = self._client.embeddings.create(**request_kwargs)
                break
            except Exception as exc:
                is_last_attempt = attempt == self._MAX_RETRIES - 1
                if is_last_attempt or not self._is_rate_limit_error(exc):
                    raise
                wait_time = 2 ** attempt  # exponential backoff: 1s, 2s, ...
                logger.warning("Bị rate limit, đợi %ss...", wait_time)
                time.sleep(wait_time)

        items = list(response.data)
        # Each item is tagged with its input position; order by it when present
        # instead of trusting the response order.
        if all(getattr(item, "index", None) is not None for item in items):
            items.sort(key=lambda item: item.index)
        if len(items) != len(batch):
            raise RuntimeError(
                f"Expected {len(batch)} embeddings from the API, got {len(items)}"
            )
        return [item.embedding for item in items]

    def _embed_texts(self, texts: Sequence[str]) -> List[List[float]]:
        """Embed *texts* in batches of ``config.batch_size``.

        Returns:
            One vector per input text, in input order; ``[]`` for empty input.

        Raises:
            ValueError: if ``config.batch_size`` is smaller than 1.
        """
        if not texts:
            return []

        batch_size = self.config.batch_size
        if batch_size < 1:
            # range() raises on a zero step and silently yields nothing for a
            # negative one; fail with a clear message instead.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        all_embeddings: List[List[float]] = []
        for start in range(0, len(texts), batch_size):
            batch = list(texts[start:start + batch_size])
            all_embeddings.extend(self._embed_batch(batch))
        return all_embeddings

    def embed_texts_np(self, texts: Sequence[str]) -> np.ndarray:
        """Embed *texts* into a float32 array with one row per text.

        Empty input yields shape ``(0, config.dimension)`` rather than ``(0,)``,
        so the result is always 2-D.
        """
        vectors = self._embed_texts(list(texts))
        if not vectors:
            return np.empty((0, self.config.dimension or 0), dtype=np.float32)
        return np.asarray(vectors, dtype=np.float32)


# Aliases kept for backward compatibility with older import names.
SiliconFlowConfig = EmbeddingConfig
get_config = get_embedding_config