riazmo committed on
Commit
bd004b0
·
verified ·
1 Parent(s): 5c4c6c5

Upload hf_inference.py

Browse files
Files changed (1) hide show
  1. core/hf_inference.py +608 -0
core/hf_inference.py ADDED
@@ -0,0 +1,608 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HuggingFace Inference Client
3
+ Design System Extractor v2
4
+
5
+ Handles all LLM inference calls using HuggingFace Inference API.
6
+ Supports diverse models from different providers for specialized tasks.
7
+ """
8
+
9
+ import os
10
+ from typing import Optional, AsyncGenerator
11
+ from dataclasses import dataclass
12
+ from huggingface_hub import InferenceClient, AsyncInferenceClient
13
+
14
+ from config.settings import get_settings
15
+
16
+
17
+ @dataclass
18
+ class ModelInfo:
19
+ """Information about a model."""
20
+ model_id: str
21
+ provider: str
22
+ context_length: int
23
+ strengths: list[str]
24
+ best_for: str
25
+ tier: str # "free", "pro", "pro+"
26
+
27
+
28
+ # =============================================================================
29
+ # COMPREHENSIVE MODEL REGISTRY — Organized by Provider
30
+ # =============================================================================
31
+
32
+ AVAILABLE_MODELS = {
33
+ # =========================================================================
34
+ # META — Llama Family (Best for reasoning)
35
+ # =========================================================================
36
+ "meta-llama/Llama-3.1-405B-Instruct": ModelInfo(
37
+ model_id="meta-llama/Llama-3.1-405B-Instruct",
38
+ provider="Meta",
39
+ context_length=128000,
40
+ strengths=["Best reasoning", "Massive knowledge", "Complex analysis"],
41
+ best_for="Agent 3 (Advisor) — PREMIUM CHOICE",
42
+ tier="pro+"
43
+ ),
44
+ "meta-llama/Llama-3.1-70B-Instruct": ModelInfo(
45
+ model_id="meta-llama/Llama-3.1-70B-Instruct",
46
+ provider="Meta",
47
+ context_length=128000,
48
+ strengths=["Excellent reasoning", "Long context", "Design knowledge"],
49
+ best_for="Agent 3 (Advisor) — RECOMMENDED",
50
+ tier="pro"
51
+ ),
52
+ "meta-llama/Llama-3.1-8B-Instruct": ModelInfo(
53
+ model_id="meta-llama/Llama-3.1-8B-Instruct",
54
+ provider="Meta",
55
+ context_length=128000,
56
+ strengths=["Fast", "Good reasoning for size", "Long context"],
57
+ best_for="Budget Agent 3 fallback",
58
+ tier="free"
59
+ ),
60
+
61
+ # =========================================================================
62
+ # MISTRAL — European Excellence
63
+ # =========================================================================
64
+ "mistralai/Mixtral-8x22B-Instruct-v0.1": ModelInfo(
65
+ model_id="mistralai/Mixtral-8x22B-Instruct-v0.1",
66
+ provider="Mistral",
67
+ context_length=65536,
68
+ strengths=["Large MoE", "Strong reasoning", "Efficient"],
69
+ best_for="Agent 3 (Advisor) — Pro alternative",
70
+ tier="pro"
71
+ ),
72
+ "mistralai/Mixtral-8x7B-Instruct-v0.1": ModelInfo(
73
+ model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
74
+ provider="Mistral",
75
+ context_length=32768,
76
+ strengths=["Good MoE efficiency", "Solid reasoning"],
77
+ best_for="Agent 3 (Advisor) — Free tier option",
78
+ tier="free"
79
+ ),
80
+ "mistralai/Mistral-7B-Instruct-v0.3": ModelInfo(
81
+ model_id="mistralai/Mistral-7B-Instruct-v0.3",
82
+ provider="Mistral",
83
+ context_length=32768,
84
+ strengths=["Fast", "Good instruction following"],
85
+ best_for="General fallback",
86
+ tier="free"
87
+ ),
88
+ "mistralai/Codestral-22B-v0.1": ModelInfo(
89
+ model_id="mistralai/Codestral-22B-v0.1",
90
+ provider="Mistral",
91
+ context_length=32768,
92
+ strengths=["Code specialist", "JSON generation", "Structured output"],
93
+ best_for="Agent 4 (Generator) — RECOMMENDED",
94
+ tier="pro"
95
+ ),
96
+
97
+ # =========================================================================
98
+ # COHERE — Command R Family (Analysis & Retrieval)
99
+ # =========================================================================
100
+ "CohereForAI/c4ai-command-r-plus": ModelInfo(
101
+ model_id="CohereForAI/c4ai-command-r-plus",
102
+ provider="Cohere",
103
+ context_length=128000,
104
+ strengths=["Excellent analysis", "RAG optimized", "Long context"],
105
+ best_for="Agent 3 (Advisor) — Great for research tasks",
106
+ tier="pro"
107
+ ),
108
+ "CohereForAI/c4ai-command-r-v01": ModelInfo(
109
+ model_id="CohereForAI/c4ai-command-r-v01",
110
+ provider="Cohere",
111
+ context_length=128000,
112
+ strengths=["Good analysis", "Efficient"],
113
+ best_for="Agent 3 budget option",
114
+ tier="free"
115
+ ),
116
+
117
+ # =========================================================================
118
+ # GOOGLE — Gemma Family
119
+ # =========================================================================
120
+ "google/gemma-2-27b-it": ModelInfo(
121
+ model_id="google/gemma-2-27b-it",
122
+ provider="Google",
123
+ context_length=8192,
124
+ strengths=["Strong instruction following", "Good balance"],
125
+ best_for="Agent 2 (Normalizer) — Quality option",
126
+ tier="pro"
127
+ ),
128
+ "google/gemma-2-9b-it": ModelInfo(
129
+ model_id="google/gemma-2-9b-it",
130
+ provider="Google",
131
+ context_length=8192,
132
+ strengths=["Fast", "Good instruction following"],
133
+ best_for="Agent 2 (Normalizer) — Balanced",
134
+ tier="free"
135
+ ),
136
+
137
+ # =========================================================================
138
+ # MICROSOFT — Phi Family (Small but Mighty)
139
+ # =========================================================================
140
+ "microsoft/Phi-3.5-mini-instruct": ModelInfo(
141
+ model_id="microsoft/Phi-3.5-mini-instruct",
142
+ provider="Microsoft",
143
+ context_length=128000,
144
+ strengths=["Very fast", "Great structured output", "Long context"],
145
+ best_for="Agent 2 (Normalizer) — RECOMMENDED",
146
+ tier="free"
147
+ ),
148
+ "microsoft/Phi-3-medium-4k-instruct": ModelInfo(
149
+ model_id="microsoft/Phi-3-medium-4k-instruct",
150
+ provider="Microsoft",
151
+ context_length=4096,
152
+ strengths=["Fast", "Good for simple tasks"],
153
+ best_for="Simple naming tasks",
154
+ tier="free"
155
+ ),
156
+
157
+ # =========================================================================
158
+ # QWEN — Alibaba Family
159
+ # =========================================================================
160
+ "Qwen/Qwen2.5-72B-Instruct": ModelInfo(
161
+ model_id="Qwen/Qwen2.5-72B-Instruct",
162
+ provider="Alibaba",
163
+ context_length=32768,
164
+ strengths=["Strong reasoning", "Multilingual", "Good design knowledge"],
165
+ best_for="Agent 3 (Advisor) — Alternative",
166
+ tier="pro"
167
+ ),
168
+ "Qwen/Qwen2.5-32B-Instruct": ModelInfo(
169
+ model_id="Qwen/Qwen2.5-32B-Instruct",
170
+ provider="Alibaba",
171
+ context_length=32768,
172
+ strengths=["Good balance", "Multilingual"],
173
+ best_for="Medium-tier option",
174
+ tier="pro"
175
+ ),
176
+ "Qwen/Qwen2.5-Coder-32B-Instruct": ModelInfo(
177
+ model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
178
+ provider="Alibaba",
179
+ context_length=32768,
180
+ strengths=["Code specialist", "JSON/structured output"],
181
+ best_for="Agent 4 (Generator) — Alternative",
182
+ tier="pro"
183
+ ),
184
+ "Qwen/Qwen2.5-7B-Instruct": ModelInfo(
185
+ model_id="Qwen/Qwen2.5-7B-Instruct",
186
+ provider="Alibaba",
187
+ context_length=32768,
188
+ strengths=["Fast", "Good all-rounder"],
189
+ best_for="General fallback",
190
+ tier="free"
191
+ ),
192
+
193
+ # =========================================================================
194
+ # DEEPSEEK — Code Specialists
195
+ # =========================================================================
196
+ "deepseek-ai/deepseek-coder-33b-instruct": ModelInfo(
197
+ model_id="deepseek-ai/deepseek-coder-33b-instruct",
198
+ provider="DeepSeek",
199
+ context_length=16384,
200
+ strengths=["Excellent code generation", "JSON specialist"],
201
+ best_for="Agent 4 (Generator) — Code focused",
202
+ tier="pro"
203
+ ),
204
+ "deepseek-ai/DeepSeek-V2.5": ModelInfo(
205
+ model_id="deepseek-ai/DeepSeek-V2.5",
206
+ provider="DeepSeek",
207
+ context_length=32768,
208
+ strengths=["Strong reasoning", "Good code"],
209
+ best_for="Multi-purpose",
210
+ tier="pro"
211
+ ),
212
+
213
+ # =========================================================================
214
+ # BIGCODE — StarCoder Family
215
+ # =========================================================================
216
+ "bigcode/starcoder2-15b-instruct-v0.1": ModelInfo(
217
+ model_id="bigcode/starcoder2-15b-instruct-v0.1",
218
+ provider="BigCode",
219
+ context_length=16384,
220
+ strengths=["Code generation", "Multiple languages"],
221
+ best_for="Agent 4 (Generator) — Open source code model",
222
+ tier="free"
223
+ ),
224
+ }
225
+
226
+
227
+ # =============================================================================
228
+ # RECOMMENDED CONFIGURATIONS BY TIER
229
+ # =============================================================================
230
+
231
+ MODEL_PRESETS = {
232
+ "budget": {
233
+ "name": "Budget (Free Tier)",
234
+ "description": "Best free models for each task",
235
+ "agent2": "microsoft/Phi-3.5-mini-instruct",
236
+ "agent3": "mistralai/Mixtral-8x7B-Instruct-v0.1",
237
+ "agent4": "bigcode/starcoder2-15b-instruct-v0.1",
238
+ "fallback": "mistralai/Mistral-7B-Instruct-v0.3",
239
+ },
240
+ "balanced": {
241
+ "name": "Balanced (Pro Tier)",
242
+ "description": "Good quality/cost balance",
243
+ "agent2": "google/gemma-2-9b-it",
244
+ "agent3": "meta-llama/Llama-3.1-70B-Instruct",
245
+ "agent4": "mistralai/Codestral-22B-v0.1",
246
+ "fallback": "Qwen/Qwen2.5-7B-Instruct",
247
+ },
248
+ "quality": {
249
+ "name": "Maximum Quality (Pro+)",
250
+ "description": "Best models regardless of cost",
251
+ "agent2": "google/gemma-2-27b-it",
252
+ "agent3": "meta-llama/Llama-3.1-405B-Instruct",
253
+ "agent4": "deepseek-ai/deepseek-coder-33b-instruct",
254
+ "fallback": "meta-llama/Llama-3.1-8B-Instruct",
255
+ },
256
+ "diverse": {
257
+ "name": "Diverse Providers",
258
+ "description": "One model from each major provider",
259
+ "agent2": "microsoft/Phi-3.5-mini-instruct", # Microsoft
260
+ "agent3": "CohereForAI/c4ai-command-r-plus", # Cohere
261
+ "agent4": "mistralai/Codestral-22B-v0.1", # Mistral
262
+ "fallback": "meta-llama/Llama-3.1-8B-Instruct", # Meta
263
+ },
264
+ }
265
+
266
+
267
+ # =============================================================================
268
+ # AGENT-SPECIFIC RECOMMENDATIONS
269
+ # =============================================================================
270
+
271
+ AGENT_MODEL_RECOMMENDATIONS = {
272
+ "crawler": {
273
+ "requires_llm": False,
274
+ "notes": "Pure rule-based extraction using Playwright + CSS parsing"
275
+ },
276
+ "extractor": {
277
+ "requires_llm": False,
278
+ "notes": "Pure rule-based extraction using Playwright + CSS parsing"
279
+ },
280
+ "normalizer": {
281
+ "requires_llm": True,
282
+ "task": "Token naming, duplicate detection, pattern inference",
283
+ "needs": ["Fast inference", "Good instruction following", "Structured output"],
284
+ "recommended": [
285
+ ("microsoft/Phi-3.5-mini-instruct", "BEST — Fast, great structured output"),
286
+ ("google/gemma-2-9b-it", "Good balance of speed and quality"),
287
+ ("Qwen/Qwen2.5-7B-Instruct", "Reliable all-rounder"),
288
+ ],
289
+ "temperature": 0.2,
290
+ },
291
+ "advisor": {
292
+ "requires_llm": True,
293
+ "task": "Design system analysis, best practice recommendations",
294
+ "needs": ["Strong reasoning", "Design knowledge", "Creative suggestions"],
295
+ "recommended": [
296
+ ("meta-llama/Llama-3.1-70B-Instruct", "BEST — Excellent reasoning"),
297
+ ("CohereForAI/c4ai-command-r-plus", "Great for analysis tasks"),
298
+ ("Qwen/Qwen2.5-72B-Instruct", "Strong alternative"),
299
+ ("mistralai/Mixtral-8x7B-Instruct-v0.1", "Best free option"),
300
+ ],
301
+ "temperature": 0.4,
302
+ },
303
+ "generator": {
304
+ "requires_llm": True,
305
+ "task": "Generate JSON tokens, CSS variables, structured output",
306
+ "needs": ["Code generation", "JSON formatting", "Schema adherence"],
307
+ "recommended": [
308
+ ("mistralai/Codestral-22B-v0.1", "BEST — Mistral's code model"),
309
+ ("deepseek-ai/deepseek-coder-33b-instruct", "Excellent code specialist"),
310
+ ("Qwen/Qwen2.5-Coder-32B-Instruct", "Strong code model"),
311
+ ("bigcode/starcoder2-15b-instruct-v0.1", "Best free option"),
312
+ ],
313
+ "temperature": 0.1,
314
+ },
315
+ }
316
+
317
+
318
+ # =============================================================================
319
+ # INFERENCE CLIENT
320
+ # =============================================================================
321
+
322
class HFInferenceClient:
    """
    Wrapper around the HuggingFace Inference API.

    Handles per-agent model selection, per-agent temperature defaults,
    and automatic fallback to a secondary model when the primary fails.
    """

    # Appended to the system prompt when json_mode=True.  Kept as a single
    # constant so the sync and async paths emit the exact same instruction.
    _JSON_INSTRUCTION = (
        "\n\nYou must respond with valid JSON only. "
        "No markdown, no explanation, just JSON."
    )

    # Default sampling temperature per agent name; 0.3 for unknown agents
    # (see get_temperature_for_agent).
    _AGENT_TEMPERATURES = {
        # Legacy agents
        "normalizer": 0.2,   # Consistent naming
        "advisor": 0.4,      # Creative recommendations
        "generator": 0.1,    # Precise formatting
        # Stage 2 agents — tuned per persona
        "brand_identifier": 0.4,          # AURORA — creative color reasoning
        "benchmark_advisor": 0.25,        # ATLAS — analytical comparison
        "best_practices_validator": 0.2,  # SENTINEL — precise rule-checking
        "head_synthesizer": 0.3,          # NEXUS — balanced synthesis
    }

    def __init__(self):
        """Create the sync and async inference clients.

        Raises:
            ValueError: If no HF token is available from the environment
                or from settings.
        """
        self.settings = get_settings()
        # Read token fresh from env — the Settings singleton may have been
        # created before the user entered their token via the Gradio UI.
        self.token = os.getenv("HF_TOKEN", "") or self.settings.hf.hf_token

        if not self.token:
            raise ValueError("HF_TOKEN is required for inference")

        # Let huggingface_hub route to the best available provider automatically.
        # Do NOT set base_url (overrides per-model routing) or
        # provider="hf-inference" (that provider no longer hosts most models).
        # The default provider="auto" picks the first available third-party
        # provider (novita, together, cerebras, etc.) for each model.
        self.sync_client = InferenceClient(token=self.token)
        self.async_client = AsyncInferenceClient(token=self.token)

    def get_model_for_agent(self, agent_name: str) -> str:
        """Get the model id configured for an agent (delegates to settings)."""
        return self.settings.get_model_for_agent(agent_name)

    def get_temperature_for_agent(self, agent_name: str) -> float:
        """Get the recommended sampling temperature for an agent (0.3 default)."""
        return self._AGENT_TEMPERATURES.get(agent_name, 0.3)

    def _build_messages(
        self,
        system_prompt: str,
        user_message: str,
        examples: Optional[list[dict]] = None
    ) -> list[dict]:
        """Build the chat message list: system, few-shot examples, user turn.

        Each example dict must carry "user" and "assistant" keys; they become
        alternating user/assistant turns before the final user message.
        An empty system_prompt adds no system turn.
        """
        messages = []

        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        if examples:
            for example in examples:
                messages.append({"role": "user", "content": example["user"]})
                messages.append({"role": "assistant", "content": example["assistant"]})

        messages.append({"role": "user", "content": user_message})

        return messages

    def _resolve_params(
        self,
        agent_name: str,
        max_tokens: Optional[int],
        temperature: Optional[float],
    ) -> tuple:
        """Resolve (model, max_tokens, temperature) for one call.

        Uses `is None` checks rather than `or` so that an explicit
        temperature=0.0 — a legitimate deterministic setting — is honored
        instead of being silently replaced by the agent default.
        """
        model = self.get_model_for_agent(agent_name)
        if max_tokens is None:
            max_tokens = self.settings.hf.max_new_tokens
        if temperature is None:
            temperature = self.get_temperature_for_agent(agent_name)
        return model, max_tokens, temperature

    def _chat_sync(
        self,
        model: str,
        messages: list[dict],
        max_tokens: int,
        temperature: float,
    ) -> str:
        """Single synchronous chat-completion call; returns the message text."""
        response = self.sync_client.chat_completion(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        return response.choices[0].message.content

    async def _chat_async(
        self,
        model: str,
        messages: list[dict],
        max_tokens: int,
        temperature: float,
    ) -> str:
        """Single asynchronous chat-completion call; returns the message text."""
        response = await self.async_client.chat_completion(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        return response.choices[0].message.content

    def complete(
        self,
        agent_name: str,
        system_prompt: str,
        user_message: str,
        examples: Optional[list[dict]] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        json_mode: bool = False,
    ) -> str:
        """
        Synchronous completion with one-shot fallback.

        Args:
            agent_name: Which agent is making the call (for model selection)
            system_prompt: System instructions
            user_message: User input
            examples: Optional few-shot examples ({"user", "assistant"} dicts)
            max_tokens: Max tokens to generate (settings default if None)
            temperature: Sampling temperature (agent default if None)
            json_mode: If True, instruct model to output JSON only

        Returns:
            Generated text

        Raises:
            Exception: Re-raised from the API when the primary model fails
                and no distinct fallback model is configured, or when the
                fallback fails too.
        """
        model, max_tokens, temperature = self._resolve_params(
            agent_name, max_tokens, temperature
        )

        if json_mode:
            system_prompt = f"{system_prompt}{self._JSON_INSTRUCTION}"

        messages = self._build_messages(system_prompt, user_message, examples)

        try:
            return self._chat_sync(model, messages, max_tokens, temperature)
        except Exception as e:
            print(f"[HF] Primary model {model} failed: {str(e)[:120]}")
            fallback = self.settings.models.fallback_model
            if fallback and fallback != model:
                print(f"[HF] Trying fallback: {fallback}")
                try:
                    return self._chat_sync(fallback, messages, max_tokens, temperature)
                except Exception as fallback_err:
                    print(f"[HF] Fallback {fallback} also failed: {str(fallback_err)[:120]}")
                    raise
            raise

    async def complete_async(
        self,
        agent_name: str,
        system_prompt: str,
        user_message: str,
        examples: Optional[list[dict]] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        json_mode: bool = False,
    ) -> str:
        """
        Asynchronous completion with one-shot fallback.

        Same parameters, return value, and error behavior as complete().
        """
        model, max_tokens, temperature = self._resolve_params(
            agent_name, max_tokens, temperature
        )

        if json_mode:
            system_prompt = f"{system_prompt}{self._JSON_INSTRUCTION}"

        messages = self._build_messages(system_prompt, user_message, examples)

        try:
            return await self._chat_async(model, messages, max_tokens, temperature)
        except Exception as e:
            print(f"[HF] Primary model {model} failed: {str(e)[:120]}")
            fallback = self.settings.models.fallback_model
            if fallback and fallback != model:
                print(f"[HF] Trying fallback: {fallback}")
                try:
                    return await self._chat_async(
                        fallback, messages, max_tokens, temperature
                    )
                except Exception as fallback_err:
                    print(f"[HF] Fallback {fallback} also failed: {str(fallback_err)[:120]}")
                    raise
            raise

    async def stream_async(
        self,
        agent_name: str,
        system_prompt: str,
        user_message: str,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
    ) -> AsyncGenerator[str, None]:
        """
        Async streaming completion.

        Yields content tokens as they are generated.  No fallback retry is
        attempted here: a stream cannot be transparently restarted on a
        different model once tokens have been yielded.
        """
        model, max_tokens, temperature = self._resolve_params(
            agent_name, max_tokens, temperature
        )

        messages = self._build_messages(system_prompt, user_message)

        # AsyncInferenceClient.chat_completion(stream=True) returns an
        # awaitable that resolves to an async iterable of chunks.
        async for chunk in await self.async_client.chat_completion(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            stream=True,
        ):
            if chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content
529
+
530
+
531
+ # =============================================================================
532
+ # SINGLETON & CONVENIENCE FUNCTIONS
533
+ # =============================================================================
534
+
535
# Module-level singleton; rebuilt when the HF token changes.
_client: Optional[HFInferenceClient] = None


def get_inference_client() -> HFInferenceClient:
    """Return the shared inference client, creating it on first use.

    A fresh client is constructed whenever the HF_TOKEN environment
    variable holds a non-empty value different from the token the cached
    client was built with (e.g. the user entered it via the Gradio UI
    after initial startup).
    """
    global _client
    env_token = os.getenv("HF_TOKEN", "")
    token_changed = bool(env_token) and (
        _client is not None and _client.token != env_token
    )
    if _client is None or token_changed:
        _client = HFInferenceClient()
    return _client
549
+
550
+
551
def complete(
    agent_name: str,
    system_prompt: str,
    user_message: str,
    **kwargs
) -> str:
    """Module-level shortcut: run a sync completion on the shared client."""
    shared = get_inference_client()
    return shared.complete(agent_name, system_prompt, user_message, **kwargs)
560
+
561
+
562
async def complete_async(
    agent_name: str,
    system_prompt: str,
    user_message: str,
    **kwargs
) -> str:
    """Module-level shortcut: run an async completion on the shared client."""
    shared = get_inference_client()
    return await shared.complete_async(agent_name, system_prompt, user_message, **kwargs)
571
+
572
+
573
def get_model_info(model_id: str) -> dict:
    """Return registry metadata for *model_id* as a plain dict.

    Unknown ids yield a minimal stub: {"model_id": ..., "provider": "unknown"}.
    """
    entry = AVAILABLE_MODELS.get(model_id)
    if entry is None:
        return {"model_id": model_id, "provider": "unknown"}
    return {
        "model_id": entry.model_id,
        "provider": entry.provider,
        "context_length": entry.context_length,
        "strengths": entry.strengths,
        "best_for": entry.best_for,
        "tier": entry.tier,
    }
586
+
587
+
588
def get_models_by_provider() -> dict[str, list[str]]:
    """Group all registry model ids by their provider name.

    Returns:
        Mapping of provider name → list of model ids, in registry order.
    """
    by_provider: dict[str, list[str]] = {}
    for model_id, info in AVAILABLE_MODELS.items():
        # setdefault replaces the check-then-insert grouping pattern.
        by_provider.setdefault(info.provider, []).append(model_id)
    return by_provider
596
+
597
+
598
def get_models_by_tier(tier: str) -> list[str]:
    """Return the ids of every registry model whose tier equals *tier*.

    Valid tiers are "free", "pro", and "pro+".
    """
    return [
        name
        for name, meta in AVAILABLE_MODELS.items()
        if meta.tier == tier
    ]
604
+
605
+
606
def get_preset_config(preset_name: str) -> dict:
    """Look up a preset model configuration, defaulting to "balanced"."""
    try:
        return MODEL_PRESETS[preset_name]
    except KeyError:
        return MODEL_PRESETS["balanced"]