File size: 1,153 Bytes
44b27f4
9f911b3
44b27f4
9f911b3
 
 
 
44b27f4
9f911b3
44b27f4
9f911b3
 
44b27f4
 
 
9f911b3
44b27f4
 
9f911b3
 
44b27f4
 
 
 
 
 
 
 
 
 
 
 
0b6760c
9f911b3
 
 
44b27f4
9f911b3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""Embedding generation via HuggingFace Inference API (no local torch needed)."""

import os
import logging

logger = logging.getLogger(__name__)

MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
EMBEDDING_DIM = 384
BATCH_SIZE = 96  # HF Inference API batch limit


def _get_client():
    """Build a HuggingFace InferenceClient.

    Uses the HF_TOKEN environment variable for authentication when set;
    falls back to anonymous access otherwise.
    """
    # Imported lazily so the module loads even without huggingface_hub installed.
    from huggingface_hub import InferenceClient

    return InferenceClient(token=os.environ.get("HF_TOKEN"))


def generate(texts: list[str]) -> list[list[float]]:
    """Encode texts into embedding vectors via the HF Inference API.

    Args:
        texts: Strings to embed. Requests are chunked into groups of
            BATCH_SIZE to stay within the Inference API batch limit.

    Returns:
        One embedding (list of floats) per input text, in input order.
        Returns an empty list for empty input without contacting the API.
    """
    # Guard: don't construct a client (or import huggingface_hub) when
    # there is nothing to embed.
    if not texts:
        return []

    client = _get_client()
    all_embeddings = []

    # Process in batches — the API rejects oversized payloads.
    for start in range(0, len(texts), BATCH_SIZE):
        batch = texts[start : start + BATCH_SIZE]
        # feature_extraction returns one vector per input string; it may be
        # a numpy array or nested lists depending on library version.
        result = client.feature_extraction(batch, model=MODEL_NAME)
        all_embeddings.extend(result)

    logger.info("Generated %d embeddings via HF Inference API", len(all_embeddings))
    # Normalize elements to plain Python floats (e.g. from numpy scalars)
    # so the result is JSON-serializable.
    return [[float(x) for x in emb] for emb in all_embeddings]


def generate_query(query: str) -> list[float]:
    """Embed a single query string (for future RAG Engine use)."""
    results = generate([query])
    return results[0]