File size: 8,405 Bytes
1bc3f18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
from stores.llm.LLMInterface import LLMInterface
import logging
import requests
import re
import os


class HuggingFaceProvider(LLMInterface):
    """``LLMInterface`` implementation backed by the Hugging Face inference router.

    * Generation posts to the OpenAI-compatible ``/v1/chat/completions`` route
      under ``self.url``.
    * Embeddings post to the ``hf-inference`` feature-extraction pipeline under
      ``EMBEDDING_ROUTER_URL``; ``self.url`` is not used for them.

    Network/API failures are logged and reported to callers as ``None``
    rather than raised.
    """

    # Base URL for feature-extraction (embedding) requests.
    EMBEDDING_ROUTER_URL = "https://router.huggingface.co"
    # Request timeouts in seconds, unchanged from the original hard-coded values.
    # NOTE(review): 6000 s (100 min) for generation is unusually long — confirm intent.
    GENERATION_TIMEOUT_S = 6000
    EMBEDDING_TIMEOUT_S = 200

    def __init__(self, url: str = None, model: str = None,
                 default_input_max_characters: int = 1000,
                 default_generation_max_output_tokens: int = 1000,
                 default_generation_temperature: float = 0.1, api_key: str = None):
        """Configure the provider.

        Args:
            url: Base URL for chat completions; defaults to the HF router.
            model: Generation model id sent as ``model`` in chat requests.
            default_input_max_characters: Stored on the instance; no method in
                this class applies it (``process_text`` does not truncate).
            default_generation_max_output_tokens: ``max_tokens`` used when
                ``generate_text`` is called without ``max_output_tokens``.
            default_generation_temperature: ``temperature`` used when
                ``generate_text`` is called without ``temperature``.
            api_key: Bearer token; falls back to the ``HF_API_KEY`` env var.
        """
        # ``url`` only affects chat completions; embeddings always use EMBEDDING_ROUTER_URL.
        self.url = url or "https://router.huggingface.co"
        self.api_key = api_key or os.getenv("HF_API_KEY")
        self.model = model
        # Mirrors ``model`` the same way ``embedding_model_id`` mirrors ``embedding_model``.
        self.generation_model_id = model

        self.embedding_model = None
        self.embedding_model_id = None
        self.embedding_size = None

        self.default_input_max_characters = default_input_max_characters
        self.default_generation_max_output_tokens = default_generation_max_output_tokens
        self.default_generation_temperature = default_generation_temperature
        self.logger = logging.getLogger(__name__)

    def set_generation_model(self, model_id: str):
        """Select the chat model used by ``generate_text``; falsy ids are ignored."""
        if model_id:
            self.model = model_id
            self.generation_model_id = model_id

    def set_embedding_model(self, model_id: str, embedding_size: int):
        """Select the embedding model and record its vector size; falsy ids are ignored."""
        if model_id:
            self.embedding_model = model_id
            self.embedding_size = embedding_size
            self.embedding_model_id = model_id

    def process_text(self, text: str):
        """Return ``text`` as a stripped string, or ``""`` for falsy input."""
        if not text:
            return ""
        return str(text).strip()

    def _auth_headers(self) -> dict:
        """Headers shared by every request: bearer auth and a JSON body."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _feature_extraction_url(self) -> str:
        """Feature-extraction pipeline URL for the current embedding model."""
        return (f"{self.EMBEDDING_ROUTER_URL}/hf-inference/models/"
                f"{self.embedding_model}/pipeline/feature-extraction")

    def generate_text(self, prompt: str, chat_history: list = None,
                      max_output_tokens: int = None, temperature: float = None):
        """Run one chat completion and return the reply with token usage.

        Args:
            prompt: User message appended after ``chat_history``.
            chat_history: Prior messages as dicts with optional ``role``
                (default ``"user"``) and ``content`` (default ``""``) keys.
            max_output_tokens: ``max_tokens`` for the request; ``None`` uses
                the instance default.
            temperature: Sampling temperature; ``None`` uses the instance
                default. ``0.0`` is honoured.

        Returns:
            A dict with keys ``model``, ``response``, ``tokens_generated``,
            ``total_duration_ms`` (always ``None``) and ``prompt_eval_tokens``;
            or ``None`` on any failure or an empty reply.
        """
        try:
            messages = [
                {"role": entry.get("role", "user"), "content": entry.get("content", "")}
                for entry in (chat_history or [])
            ]
            messages.append({"role": "user", "content": self.process_text(prompt)})

            # Compare with None (not truthiness) so explicit 0 / 0.0 are not
            # replaced by the defaults — temperature=0.0 is a common choice.
            if max_output_tokens is None:
                max_output_tokens = self.default_generation_max_output_tokens
            if temperature is None:
                temperature = self.default_generation_temperature

            payload = {
                "model": self.model,
                "messages": messages,
                "max_tokens": int(max_output_tokens),
                "temperature": float(temperature),
            }

            # HF Inference API (serverless): /v1/chat/completions (OpenAI-compatible)
            url = self.url.rstrip("/") + "/v1/chat/completions"
            resp = requests.post(url, json=payload, headers=self._auth_headers(),
                                 timeout=self.GENERATION_TIMEOUT_S)
            if resp.status_code != 200:
                self.logger.error("HuggingFace generate failed: %s %s", resp.status_code, resp.text)
                return None

            data = resp.json()

            try:
                generated_text = data["choices"][0]["message"]["content"].strip()
            except (KeyError, IndexError, TypeError, AttributeError):
                # AttributeError covers a null ``content`` field.
                self.logger.error("Unexpected HuggingFace response structure: %s", data)
                return None

            if not generated_text:
                return None

            # ``usage`` may be absent or an explicit JSON null.
            usage = data.get("usage") or {}
            return {
                "model": data.get("model"),
                "response": generated_text,
                "tokens_generated": usage.get("completion_tokens"),
                "total_duration_ms": None,
                "prompt_eval_tokens": usage.get("prompt_tokens"),
            }

        except Exception as e:
            self.logger.exception("Error in HuggingFaceProvider.generate_text: %s", e)
            return None

    @staticmethod
    def _parse_single_embedding(data):
        """Extract one flat vector from a feature-extraction response, else ``None``.

        Accepts ``[[x, ...], ...]`` (the first row is used), a flat
        ``[x, ...]``, or ``{"embedding": [...]}``.
        """
        if isinstance(data, list):
            if data and isinstance(data[0], list):
                return data[0]  # [[float, ...]] -> [float, ...]
            if data and isinstance(data[0], (int, float)):
                return data     # [float, ...] already flat
            return None
        if isinstance(data, dict) and "embedding" in data:
            return data["embedding"]
        return None

    def embed_text(self, text: str, document_type: str = None):
        """Embed a single text.

        Args:
            text: Input text; cleaned with ``process_text`` before sending.
            document_type: Not used by this provider.

        Returns:
            The embedding vector, ``[]`` if the cleaned text is empty, or
            ``None`` if no embedding model is set or the request/parse fails.
        """
        try:
            if not self.embedding_model:
                self.logger.error("Embedding model is not set before calling embed_text()")
                return None

            clean_text = self.process_text(text)
            if not clean_text:
                return []

            resp = requests.post(self._feature_extraction_url(),
                                 json={"inputs": clean_text},
                                 headers=self._auth_headers(),
                                 timeout=self.EMBEDDING_TIMEOUT_S)
            if resp.status_code != 200:
                self.logger.error("HuggingFace embedding failed: %s %s", resp.status_code, resp.text)
                return None

            data = resp.json()
            embedding = self._parse_single_embedding(data)
            if embedding is None:
                self.logger.warning("Unexpected HuggingFace embedding response structure: %s", data)
                return None

            self.logger.debug("Embedding length: %d", len(embedding))
            return embedding

        except Exception as e:
            self.logger.exception("Error in HuggingFaceProvider.embed_text: %s", e)
            return None

    def construct_prompt(self, prompt: str, role: str):
        """Build a chat message dict ``{"role": role, "content": <cleaned prompt>}``."""
        return {
            "role": role,
            "content": self.process_text(prompt)
        }

    def _request_embeddings(self, url: str, headers: dict, inputs: list):
        """POST one feature-extraction batch; return exactly one vector per input, or ``None``."""
        try:
            resp = requests.post(url, json={"inputs": inputs}, headers=headers,
                                 timeout=self.EMBEDDING_TIMEOUT_S)
        except requests.RequestException as e:
            self.logger.error("HuggingFace embedding request failed: %s", e)
            return None
        if resp.status_code != 200:
            self.logger.error("HuggingFace embedding failed: %s %s", resp.status_code, resp.text)
            return None
        try:
            data = resp.json()
        except ValueError:
            self.logger.error("HuggingFace embedding response is not valid JSON: %s", resp.text)
            return None

        # Batch response: [[f, ...], [f, ...]]; a single input may come back flat.
        embeddings = None
        if isinstance(data, list) and data:
            if isinstance(data[0], list):
                embeddings = data
            elif isinstance(data[0], (int, float)):
                embeddings = [data]

        if not embeddings:
            self.logger.error("No embeddings returned from HuggingFace")
            return None
        # A count mismatch would silently misalign vectors with their texts.
        if len(embeddings) != len(inputs):
            self.logger.error("HuggingFace returned %d embeddings for %d inputs",
                              len(embeddings), len(inputs))
            return None

        self.logger.debug("Received %d embeddings", len(embeddings))
        return embeddings

    def embed_text_batch(self, texts: list[str], batch_size: int = 32):
        """Embed ``texts`` in chunks of ``batch_size``, preserving input order.

        Returns:
            A list with exactly one entry per element of ``texts``: its
            embedding vector, or ``[]`` when the text is empty after
            ``process_text`` (mirroring ``embed_text``); ``[]`` for empty or
            ``None`` ``texts``; ``None`` if no embedding model is set or any
            batch request fails or returns an unexpected shape.

        Raises:
            ValueError: if ``batch_size`` is less than 1.
        """
        if not self.embedding_model:
            self.logger.error("Embedding model not set")
            return None
        if not texts:
            return []
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.logger.info("Embedding %d texts using batch_size=%d", len(texts), batch_size)

        url = self._feature_extraction_url()
        headers = self._auth_headers()
        all_embeddings = []

        for start in range(0, len(texts), batch_size):
            cleaned = [self.process_text(t) for t in texts[start:start + batch_size]]
            # Only non-empty texts are sent; empty ones get [] so the output
            # stays index-aligned with ``texts``.
            to_send = [t for t in cleaned if t]

            batch_vectors = []
            if to_send:
                batch_vectors = self._request_embeddings(url, headers, to_send)
                if batch_vectors is None:
                    return None

            vectors = iter(batch_vectors)
            all_embeddings.extend(next(vectors) if t else [] for t in cleaned)

        self.logger.info("Total embeddings created: %d", len(all_embeddings))
        return all_embeddings

    def clean_content(self, text: str) -> str:
        """Remove markdown links and bracketed spans, then collapse newline runs.

        ``[label](target)`` is removed entirely, label included (this pattern
        does not span lines); any remaining ``[...]`` span is then removed.
        Runs of consecutive newlines become one and the result is stripped.
        Falsy input returns ``""``.
        """
        if not text:
            return ""
        text = re.sub(r'\[.*?\]\(.*?\)', '', text)
        text = re.sub(r'\[[^\]]*\]', '', text)
        text = re.sub(r'\n+', '\n', text).strip()
        return text

    def web_search(self, query: str):
        """HuggingFace Inference API has no native web search — returns a not-supported notice."""
        self.logger.warning("HuggingFaceProvider.web_search is not natively supported.")
        return {
            "text": "Web search is not natively supported by the HuggingFace Inference API.",
            "references": []
        }