riazmo committed on
Commit
2a51260
·
verified ·
1 Parent(s): b81141e

Delete core/hf_inference.py

Browse files
Files changed (1) hide show
  1. core/hf_inference.py +0 -602
core/hf_inference.py DELETED
@@ -1,602 +0,0 @@
1
- """
2
- HuggingFace Inference Client
3
- Design System Extractor v2
4
-
5
- Handles all LLM inference calls using HuggingFace Inference API.
6
- Supports diverse models from different providers for specialized tasks.
7
- """
8
-
9
- import os
10
- from typing import Optional, AsyncGenerator
11
- from dataclasses import dataclass
12
- from huggingface_hub import InferenceClient, AsyncInferenceClient
13
-
14
- from config.settings import get_settings
15
-
16
-
17
@dataclass
class ModelInfo:
    """Information about a model."""
    model_id: str        # HuggingFace Hub repo ID (also the registry key)
    provider: str        # Organization behind the model (e.g. "Meta", "Mistral")
    context_length: int  # Maximum context window in tokens
    strengths: list[str] # Short human-readable strengths, shown in UIs
    best_for: str        # Which pipeline agent this model suits best
    tier: str  # "free", "pro", "pro+"
26
-
27
-
28
# =============================================================================
# COMPREHENSIVE MODEL REGISTRY — Organized by Provider
# =============================================================================

# Keys intentionally duplicate ModelInfo.model_id so the registry can be
# looked up directly by Hub repo ID. Model IDs are runtime strings — do not
# rename them without checking MODEL_PRESETS and AGENT_MODEL_RECOMMENDATIONS,
# which reference the same IDs.
AVAILABLE_MODELS = {
    # =========================================================================
    # META — Llama Family (Best for reasoning)
    # =========================================================================
    "meta-llama/Llama-3.1-405B-Instruct": ModelInfo(
        model_id="meta-llama/Llama-3.1-405B-Instruct",
        provider="Meta",
        context_length=128000,
        strengths=["Best reasoning", "Massive knowledge", "Complex analysis"],
        best_for="Agent 3 (Advisor) — PREMIUM CHOICE",
        tier="pro+"
    ),
    "meta-llama/Llama-3.1-70B-Instruct": ModelInfo(
        model_id="meta-llama/Llama-3.1-70B-Instruct",
        provider="Meta",
        context_length=128000,
        strengths=["Excellent reasoning", "Long context", "Design knowledge"],
        best_for="Agent 3 (Advisor) — RECOMMENDED",
        tier="pro"
    ),
    "meta-llama/Llama-3.1-8B-Instruct": ModelInfo(
        model_id="meta-llama/Llama-3.1-8B-Instruct",
        provider="Meta",
        context_length=128000,
        strengths=["Fast", "Good reasoning for size", "Long context"],
        best_for="Budget Agent 3 fallback",
        tier="free"
    ),

    # =========================================================================
    # MISTRAL — European Excellence
    # =========================================================================
    "mistralai/Mixtral-8x22B-Instruct-v0.1": ModelInfo(
        model_id="mistralai/Mixtral-8x22B-Instruct-v0.1",
        provider="Mistral",
        context_length=65536,
        strengths=["Large MoE", "Strong reasoning", "Efficient"],
        best_for="Agent 3 (Advisor) — Pro alternative",
        tier="pro"
    ),
    "mistralai/Mixtral-8x7B-Instruct-v0.1": ModelInfo(
        model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
        provider="Mistral",
        context_length=32768,
        strengths=["Good MoE efficiency", "Solid reasoning"],
        best_for="Agent 3 (Advisor) — Free tier option",
        tier="free"
    ),
    "mistralai/Mistral-7B-Instruct-v0.3": ModelInfo(
        model_id="mistralai/Mistral-7B-Instruct-v0.3",
        provider="Mistral",
        context_length=32768,
        strengths=["Fast", "Good instruction following"],
        best_for="General fallback",
        tier="free"
    ),
    "mistralai/Codestral-22B-v0.1": ModelInfo(
        model_id="mistralai/Codestral-22B-v0.1",
        provider="Mistral",
        context_length=32768,
        strengths=["Code specialist", "JSON generation", "Structured output"],
        best_for="Agent 4 (Generator) — RECOMMENDED",
        tier="pro"
    ),

    # =========================================================================
    # COHERE — Command R Family (Analysis & Retrieval)
    # =========================================================================
    "CohereForAI/c4ai-command-r-plus": ModelInfo(
        model_id="CohereForAI/c4ai-command-r-plus",
        provider="Cohere",
        context_length=128000,
        strengths=["Excellent analysis", "RAG optimized", "Long context"],
        best_for="Agent 3 (Advisor) — Great for research tasks",
        tier="pro"
    ),
    "CohereForAI/c4ai-command-r-v01": ModelInfo(
        model_id="CohereForAI/c4ai-command-r-v01",
        provider="Cohere",
        context_length=128000,
        strengths=["Good analysis", "Efficient"],
        best_for="Agent 3 budget option",
        tier="free"
    ),

    # =========================================================================
    # GOOGLE — Gemma Family
    # =========================================================================
    "google/gemma-2-27b-it": ModelInfo(
        model_id="google/gemma-2-27b-it",
        provider="Google",
        context_length=8192,
        strengths=["Strong instruction following", "Good balance"],
        best_for="Agent 2 (Normalizer) — Quality option",
        tier="pro"
    ),
    "google/gemma-2-9b-it": ModelInfo(
        model_id="google/gemma-2-9b-it",
        provider="Google",
        context_length=8192,
        strengths=["Fast", "Good instruction following"],
        best_for="Agent 2 (Normalizer) — Balanced",
        tier="free"
    ),

    # =========================================================================
    # MICROSOFT — Phi Family (Small but Mighty)
    # =========================================================================
    "microsoft/Phi-3.5-mini-instruct": ModelInfo(
        model_id="microsoft/Phi-3.5-mini-instruct",
        provider="Microsoft",
        context_length=128000,
        strengths=["Very fast", "Great structured output", "Long context"],
        best_for="Agent 2 (Normalizer) — RECOMMENDED",
        tier="free"
    ),
    "microsoft/Phi-3-medium-4k-instruct": ModelInfo(
        model_id="microsoft/Phi-3-medium-4k-instruct",
        provider="Microsoft",
        context_length=4096,
        strengths=["Fast", "Good for simple tasks"],
        best_for="Simple naming tasks",
        tier="free"
    ),

    # =========================================================================
    # QWEN — Alibaba Family
    # =========================================================================
    "Qwen/Qwen2.5-72B-Instruct": ModelInfo(
        model_id="Qwen/Qwen2.5-72B-Instruct",
        provider="Alibaba",
        context_length=32768,
        strengths=["Strong reasoning", "Multilingual", "Good design knowledge"],
        best_for="Agent 3 (Advisor) — Alternative",
        tier="pro"
    ),
    "Qwen/Qwen2.5-32B-Instruct": ModelInfo(
        model_id="Qwen/Qwen2.5-32B-Instruct",
        provider="Alibaba",
        context_length=32768,
        strengths=["Good balance", "Multilingual"],
        best_for="Medium-tier option",
        tier="pro"
    ),
    "Qwen/Qwen2.5-Coder-32B-Instruct": ModelInfo(
        model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
        provider="Alibaba",
        context_length=32768,
        strengths=["Code specialist", "JSON/structured output"],
        best_for="Agent 4 (Generator) — Alternative",
        tier="pro"
    ),
    "Qwen/Qwen2.5-7B-Instruct": ModelInfo(
        model_id="Qwen/Qwen2.5-7B-Instruct",
        provider="Alibaba",
        context_length=32768,
        strengths=["Fast", "Good all-rounder"],
        best_for="General fallback",
        tier="free"
    ),

    # =========================================================================
    # DEEPSEEK — Code Specialists
    # =========================================================================
    "deepseek-ai/deepseek-coder-33b-instruct": ModelInfo(
        model_id="deepseek-ai/deepseek-coder-33b-instruct",
        provider="DeepSeek",
        context_length=16384,
        strengths=["Excellent code generation", "JSON specialist"],
        best_for="Agent 4 (Generator) — Code focused",
        tier="pro"
    ),
    "deepseek-ai/DeepSeek-V2.5": ModelInfo(
        model_id="deepseek-ai/DeepSeek-V2.5",
        provider="DeepSeek",
        context_length=32768,
        strengths=["Strong reasoning", "Good code"],
        best_for="Multi-purpose",
        tier="pro"
    ),

    # =========================================================================
    # BIGCODE — StarCoder Family
    # =========================================================================
    "bigcode/starcoder2-15b-instruct-v0.1": ModelInfo(
        model_id="bigcode/starcoder2-15b-instruct-v0.1",
        provider="BigCode",
        context_length=16384,
        strengths=["Code generation", "Multiple languages"],
        best_for="Agent 4 (Generator) — Open source code model",
        tier="free"
    ),
}
225
-
226
-
227
# =============================================================================
# RECOMMENDED CONFIGURATIONS BY TIER
# =============================================================================

# Each preset maps pipeline roles to model IDs from AVAILABLE_MODELS.
# Per the registry's best_for fields: agent2 = Normalizer, agent3 = Advisor,
# agent4 = Generator; "fallback" is the backup model for any role.
MODEL_PRESETS = {
    "budget": {
        "name": "Budget (Free Tier)",
        "description": "Best free models for each task",
        "agent2": "microsoft/Phi-3.5-mini-instruct",
        "agent3": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "agent4": "bigcode/starcoder2-15b-instruct-v0.1",
        "fallback": "mistralai/Mistral-7B-Instruct-v0.3",
    },
    "balanced": {
        "name": "Balanced (Pro Tier)",
        "description": "Good quality/cost balance",
        "agent2": "google/gemma-2-9b-it",
        "agent3": "meta-llama/Llama-3.1-70B-Instruct",
        "agent4": "mistralai/Codestral-22B-v0.1",
        "fallback": "Qwen/Qwen2.5-7B-Instruct",
    },
    "quality": {
        "name": "Maximum Quality (Pro+)",
        "description": "Best models regardless of cost",
        "agent2": "google/gemma-2-27b-it",
        "agent3": "meta-llama/Llama-3.1-405B-Instruct",
        "agent4": "deepseek-ai/deepseek-coder-33b-instruct",
        "fallback": "meta-llama/Llama-3.1-8B-Instruct",
    },
    "diverse": {
        "name": "Diverse Providers",
        "description": "One model from each major provider",
        "agent2": "microsoft/Phi-3.5-mini-instruct",  # Microsoft
        "agent3": "CohereForAI/c4ai-command-r-plus",  # Cohere
        "agent4": "mistralai/Codestral-22B-v0.1",  # Mistral
        "fallback": "meta-llama/Llama-3.1-8B-Instruct",  # Meta
    },
}
265
-
266
-
267
# =============================================================================
# AGENT-SPECIFIC RECOMMENDATIONS
# =============================================================================

# Documentation/UI data describing what each pipeline agent needs from an LLM.
# "recommended" entries are (model_id, rationale) pairs; "temperature" values
# mirror the per-agent defaults in HFInferenceClient.get_temperature_for_agent.
AGENT_MODEL_RECOMMENDATIONS = {
    "crawler": {
        "requires_llm": False,
        "notes": "Pure rule-based extraction using Playwright + CSS parsing"
    },
    "extractor": {
        "requires_llm": False,
        "notes": "Pure rule-based extraction using Playwright + CSS parsing"
    },
    "normalizer": {
        "requires_llm": True,
        "task": "Token naming, duplicate detection, pattern inference",
        "needs": ["Fast inference", "Good instruction following", "Structured output"],
        "recommended": [
            ("microsoft/Phi-3.5-mini-instruct", "BEST — Fast, great structured output"),
            ("google/gemma-2-9b-it", "Good balance of speed and quality"),
            ("Qwen/Qwen2.5-7B-Instruct", "Reliable all-rounder"),
        ],
        "temperature": 0.2,
    },
    "advisor": {
        "requires_llm": True,
        "task": "Design system analysis, best practice recommendations",
        "needs": ["Strong reasoning", "Design knowledge", "Creative suggestions"],
        "recommended": [
            ("meta-llama/Llama-3.1-70B-Instruct", "BEST — Excellent reasoning"),
            ("CohereForAI/c4ai-command-r-plus", "Great for analysis tasks"),
            ("Qwen/Qwen2.5-72B-Instruct", "Strong alternative"),
            ("mistralai/Mixtral-8x7B-Instruct-v0.1", "Best free option"),
        ],
        "temperature": 0.4,
    },
    "generator": {
        "requires_llm": True,
        "task": "Generate JSON tokens, CSS variables, structured output",
        "needs": ["Code generation", "JSON formatting", "Schema adherence"],
        "recommended": [
            ("mistralai/Codestral-22B-v0.1", "BEST — Mistral's code model"),
            ("deepseek-ai/deepseek-coder-33b-instruct", "Excellent code specialist"),
            ("Qwen/Qwen2.5-Coder-32B-Instruct", "Strong code model"),
            ("bigcode/starcoder2-15b-instruct-v0.1", "Best free option"),
        ],
        "temperature": 0.1,
    },
}
316
-
317
-
318
- # =============================================================================
319
- # INFERENCE CLIENT
320
- # =============================================================================
321
-
322
class HFInferenceClient:
    """
    Wrapper around the HuggingFace Inference API.

    Handles model selection, per-agent generation defaults, and automatic
    fallback to a secondary model when the primary model fails.
    """

    # Appended to the system prompt when json_mode=True. Kept byte-identical
    # to the instruction previously duplicated in complete()/complete_async().
    _JSON_INSTRUCTION = (
        "\n\nYou must respond with valid JSON only. "
        "No markdown, no explanation, just JSON."
    )

    def __init__(self):
        """Create the sync and async inference clients.

        Raises:
            ValueError: If no HF token is found in the environment or settings.
        """
        self.settings = get_settings()
        # Read token fresh from env — the Settings singleton may have been
        # created before the user entered their token via the Gradio UI.
        self.token = os.getenv("HF_TOKEN", "") or self.settings.hf.hf_token

        if not self.token:
            raise ValueError("HF_TOKEN is required for inference")

        # Let huggingface_hub route to the best available provider automatically.
        # Do NOT set base_url (overrides per-model routing) or
        # provider="hf-inference" (that provider no longer hosts most models).
        # The default provider="auto" picks the first available third-party
        # provider (novita, together, cerebras, etc.) for each model.
        self.sync_client = InferenceClient(token=self.token)
        self.async_client = AsyncInferenceClient(token=self.token)

    def get_model_for_agent(self, agent_name: str) -> str:
        """Get the appropriate model for an agent (delegates to settings)."""
        return self.settings.get_model_for_agent(agent_name)

    def get_temperature_for_agent(self, agent_name: str) -> float:
        """Get recommended temperature for an agent (0.3 for unknown agents)."""
        temps = {
            "normalizer": 0.2,  # Consistent naming
            "advisor": 0.4,     # Creative recommendations
            "generator": 0.1,   # Precise formatting
        }
        return temps.get(agent_name, 0.3)

    def _resolve_params(
        self,
        agent_name: str,
        max_tokens: Optional[int],
        temperature: Optional[float],
    ) -> tuple:
        """Resolve (model, max_tokens, temperature), applying agent defaults.

        Uses explicit `is None` checks: the previous `x or default` idiom
        silently replaced an intentional temperature of 0.0 (greedy decoding)
        with the agent default.
        """
        model = self.get_model_for_agent(agent_name)
        if max_tokens is None:
            max_tokens = self.settings.hf.max_new_tokens
        if temperature is None:
            temperature = self.get_temperature_for_agent(agent_name)
        return model, max_tokens, temperature

    def _build_messages(
        self,
        system_prompt: str,
        user_message: str,
        examples: Optional[list[dict]] = None
    ) -> list[dict]:
        """Build an OpenAI-style message list for chat completion.

        Args:
            system_prompt: System instructions; skipped when empty.
            user_message: Final user turn.
            examples: Optional few-shot examples, each a dict with
                "user" and "assistant" keys.

        Returns:
            List of {"role": ..., "content": ...} dicts.
        """
        messages = []

        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        # Few-shot examples are inserted as alternating user/assistant turns.
        for example in examples or []:
            messages.append({"role": "user", "content": example["user"]})
            messages.append({"role": "assistant", "content": example["assistant"]})

        messages.append({"role": "user", "content": user_message})

        return messages

    def complete(
        self,
        agent_name: str,
        system_prompt: str,
        user_message: str,
        examples: Optional[list[dict]] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        json_mode: bool = False,
    ) -> str:
        """
        Synchronous completion.

        Args:
            agent_name: Which agent is making the call (for model selection)
            system_prompt: System instructions
            user_message: User input
            examples: Optional few-shot examples
            max_tokens: Max tokens to generate (settings default if None)
            temperature: Sampling temperature (agent default if None; an
                explicit 0.0 is honored)
            json_mode: If True, instruct model to output JSON

        Returns:
            Generated text

        Raises:
            Exception: Re-raises the provider error when both the primary
                model and the configured fallback fail.
        """
        model, max_tokens, temperature = self._resolve_params(
            agent_name, max_tokens, temperature
        )

        if json_mode:
            system_prompt = f"{system_prompt}{self._JSON_INSTRUCTION}"

        messages = self._build_messages(system_prompt, user_message, examples)

        try:
            response = self.sync_client.chat_completion(
                model=model,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )
            return response.choices[0].message.content

        except Exception as e:
            error_msg = str(e)
            print(f"[HF] Primary model {model} failed: {error_msg[:120]}")
            fallback = self.settings.models.fallback_model
            if fallback and fallback != model:
                print(f"[HF] Trying fallback: {fallback}")
                try:
                    response = self.sync_client.chat_completion(
                        model=fallback,
                        messages=messages,
                        max_tokens=max_tokens,
                        temperature=temperature,
                    )
                    return response.choices[0].message.content
                except Exception as fallback_err:
                    print(f"[HF] Fallback {fallback} also failed: {str(fallback_err)[:120]}")
                    raise fallback_err
            raise e

    async def complete_async(
        self,
        agent_name: str,
        system_prompt: str,
        user_message: str,
        examples: Optional[list[dict]] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        json_mode: bool = False,
    ) -> str:
        """
        Asynchronous completion.

        Same parameters, return value, and fallback behavior as complete().
        """
        model, max_tokens, temperature = self._resolve_params(
            agent_name, max_tokens, temperature
        )

        if json_mode:
            system_prompt = f"{system_prompt}{self._JSON_INSTRUCTION}"

        messages = self._build_messages(system_prompt, user_message, examples)

        try:
            response = await self.async_client.chat_completion(
                model=model,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )
            return response.choices[0].message.content

        except Exception as e:
            error_msg = str(e)
            print(f"[HF] Primary model {model} failed: {error_msg[:120]}")
            fallback = self.settings.models.fallback_model
            if fallback and fallback != model:
                print(f"[HF] Trying fallback: {fallback}")
                try:
                    response = await self.async_client.chat_completion(
                        model=fallback,
                        messages=messages,
                        max_tokens=max_tokens,
                        temperature=temperature,
                    )
                    return response.choices[0].message.content
                except Exception as fallback_err:
                    print(f"[HF] Fallback {fallback} also failed: {str(fallback_err)[:120]}")
                    raise fallback_err
            raise e

    async def stream_async(
        self,
        agent_name: str,
        system_prompt: str,
        user_message: str,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
    ) -> AsyncGenerator[str, None]:
        """
        Async streaming completion.

        Yields tokens as they are generated. No fallback model is attempted
        for streaming calls; provider errors propagate to the caller.
        """
        model, max_tokens, temperature = self._resolve_params(
            agent_name, max_tokens, temperature
        )

        messages = self._build_messages(system_prompt, user_message)

        # chat_completion is a coroutine; with stream=True it resolves to an
        # async iterable of chunks, hence `async for ... in await ...`.
        async for chunk in await self.async_client.chat_completion(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            stream=True,
        ):
            if chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content
523
-
524
-
525
- # =============================================================================
526
- # SINGLETON & CONVENIENCE FUNCTIONS
527
- # =============================================================================
528
-
529
# Module-level singleton, created on first use by get_inference_client().
_client: Optional[HFInferenceClient] = None
530
-
531
-
532
def get_inference_client() -> HFInferenceClient:
    """Return the process-wide inference client, creating it on demand.

    A new client is built when none exists yet, or when a non-empty
    HF_TOKEN in the environment differs from the token the cached client
    was created with (e.g. the user entered it via the Gradio UI after
    initial startup).
    """
    global _client
    env_token = os.getenv("HF_TOKEN", "")
    stale = (
        _client is not None
        and bool(env_token)
        and _client.token != env_token
    )
    if _client is None or stale:
        _client = HFInferenceClient()
    return _client
543
-
544
-
545
def complete(
    agent_name: str,
    system_prompt: str,
    user_message: str,
    **kwargs
) -> str:
    """Convenience function for sync completion."""
    return get_inference_client().complete(
        agent_name, system_prompt, user_message, **kwargs
    )
554
-
555
-
556
async def complete_async(
    agent_name: str,
    system_prompt: str,
    user_message: str,
    **kwargs
) -> str:
    """Convenience function for async completion."""
    return await get_inference_client().complete_async(
        agent_name, system_prompt, user_message, **kwargs
    )
565
-
566
-
567
def get_model_info(model_id: str) -> dict:
    """Get information about a specific model.

    Returns a minimal {"model_id": ..., "provider": "unknown"} dict for
    models not present in the registry.
    """
    info = AVAILABLE_MODELS.get(model_id)
    if info is None:
        return {"model_id": model_id, "provider": "unknown"}
    return {
        "model_id": info.model_id,
        "provider": info.provider,
        "context_length": info.context_length,
        "strengths": info.strengths,
        "best_for": info.best_for,
        "tier": info.tier,
    }
580
-
581
-
582
def get_models_by_provider() -> dict[str, list[str]]:
    """Get all models grouped by provider."""
    grouped: dict[str, list[str]] = {}
    for model_id, info in AVAILABLE_MODELS.items():
        grouped.setdefault(info.provider, []).append(model_id)
    return grouped
590
-
591
-
592
def get_models_by_tier(tier: str) -> list[str]:
    """Get all models for a specific tier (free, pro, pro+)."""
    matching: list[str] = []
    for model_id, info in AVAILABLE_MODELS.items():
        if info.tier == tier:
            matching.append(model_id)
    return matching
598
-
599
-
600
def get_preset_config(preset_name: str) -> dict:
    """Get a preset model configuration ("balanced" for unknown names)."""
    if preset_name in MODEL_PRESETS:
        return MODEL_PRESETS[preset_name]
    return MODEL_PRESETS["balanced"]