from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ModelInfo:
    """
    Represents metadata for an inference model.

    Attributes:
        name: Human-readable name of the model.
        id: Unique model identifier (HF/externally routed).
        description: Short description of the model's capabilities.
        default_provider: Preferred inference provider ("auto", "groq", "openai", "gemini", "fireworks").
    """
    name: str
    id: str
    description: str
    default_provider: str = "auto"


AVAILABLE_MODELS: List[ModelInfo] = [
    ModelInfo(
        name="Moonshot Kimi-K2",
        id="moonshotai/Kimi-K2-Instruct",
        description="Moonshot AI Kimi-K2-Instruct model for code generation and general tasks",
        default_provider="groq",
    ),
    ModelInfo(
        name="DeepSeek V3",
        id="deepseek-ai/DeepSeek-V3-0324",
        description="DeepSeek V3 model for code generation",
    ),
    ModelInfo(
        name="DeepSeek R1",
        id="deepseek-ai/DeepSeek-R1-0528",
        description="DeepSeek R1 model for code generation",
    ),
    ModelInfo(
        name="ERNIE-4.5-VL",
        id="baidu/ERNIE-4.5-VL-424B-A47B-Base-PT",
        description="ERNIE-4.5-VL model for multimodal code generation with image support",
    ),
    ModelInfo(
        name="MiniMax M1",
        id="MiniMaxAI/MiniMax-M1-80k",
        description="MiniMax M1 model for code generation and general tasks",
    ),
    ModelInfo(
        name="Qwen3-235B-A22B",
        id="Qwen/Qwen3-235B-A22B",
        description="Qwen3-235B-A22B model for code generation and general tasks",
    ),
    ModelInfo(
        name="SmolLM3-3B",
        id="HuggingFaceTB/SmolLM3-3B",
        description="SmolLM3-3B model for code generation and general tasks",
    ),
    ModelInfo(
        name="GLM-4.1V-9B-Thinking",
        id="THUDM/GLM-4.1V-9B-Thinking",
        description="GLM-4.1V-9B-Thinking model for multimodal code generation with image support",
    ),
    ModelInfo(
        name="OpenAI GPT-4",
        id="openai/gpt-4",
        description="OpenAI GPT-4 model via HF Inference Providers",
        default_provider="openai",
    ),
    ModelInfo(
        name="Gemini Pro",
        id="gemini/pro",
        description="Google Gemini Pro model via HF Inference Providers",
        default_provider="gemini",
    ),
    ModelInfo(
        name="Fireworks AI",
        id="fireworks-ai/fireworks-v1",
        description="Fireworks AI model via HF Inference Providers",
        default_provider="fireworks",
    ),
]
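

# Illustrative usage sketch, not part of the module API: group registry
# entries by their default provider, e.g. to populate a model-picker UI.
# The helper name `_example_registry` is hypothetical.
def _example_registry() -> None:
    groups: dict = {}
    for model in AVAILABLE_MODELS:
        groups.setdefault(model.default_provider, []).append(model.name)
    for provider, names in sorted(groups.items()):
        print(f"{provider}: {', '.join(names)}")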


def find_model(identifier: str) -> Optional[ModelInfo]:
    """
    Look up a model by its human-readable name or identifier.

    Args:
        identifier: ModelInfo.name (case-insensitive) or ModelInfo.id.

    Returns:
        The matching ModelInfo, or None if not found.
    """
    identifier_lower = identifier.lower()
    for model in AVAILABLE_MODELS:
        if model.id == identifier or model.name.lower() == identifier_lower:
            return model
    return None
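

# Usage sketch (illustrative only): find_model matches case-insensitively on
# the name and exactly on the id. `_example_lookup` is a hypothetical helper.
def _example_lookup() -> None:
    assert find_model("deepseek v3") is not None                   # name, case-insensitive
    assert find_model("deepseek-ai/DeepSeek-V3-0324") is not None  # exact id
    assert find_model("unknown-model") is None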


from typing import Dict, List, Optional

from hf_client import get_inference_client


def chat_completion(
    model_id: str,
    messages: List[Dict[str, str]],
    provider: Optional[str] = None,
    max_tokens: int = 4096,
) -> str:
    """
    Send a chat completion request to the appropriate inference provider.

    Args:
        model_id: The model identifier to use.
        messages: A list of OpenAI-style {"role", "content"} messages.
        provider: Optional provider override; falls back to "auto" routing when None.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The assistant's response content.
    """
    client = get_inference_client(model_id, provider or "auto")
    response = client.chat.completions.create(
        model=model_id,
        messages=messages,
        max_tokens=max_tokens,
    )
    return response.choices[0].message.content
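

# Usage sketch (illustrative only): assumes hf_client resolves credentials
# (e.g. an HF token) and that the model id is reachable via the chosen
# provider. Calling this performs a real network request, so it is wrapped in
# a hypothetical helper rather than run at import time.
def _example_chat() -> None:
    reply = chat_completion(
        model_id="deepseek-ai/DeepSeek-V3-0324",
        messages=[{"role": "user", "content": "Write a one-line Python hello world."}],
        max_tokens=128,
    )
    print(reply)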


def stream_chat_completion(
    model_id: str,
    messages: List[Dict[str, str]],
    provider: Optional[str] = None,
    max_tokens: int = 4096,
):
    """
    Generator for streaming chat completions.

    Yields partial message chunks as strings.
    """
    client = get_inference_client(model_id, provider or "auto")
    stream = client.chat.completions.create(
        model=model_id,
        messages=messages,
        max_tokens=max_tokens,
        stream=True,
    )
    for chunk in stream:
        # Some providers send chunks without choices or content (e.g.
        # keep-alives); skip those instead of raising.
        if not chunk.choices:
            continue
        delta = getattr(chunk.choices[0].delta, "content", None)
        if delta:
            yield delta
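

# Usage sketch (illustrative only): print tokens as they arrive. Same
# assumptions as `_example_chat` above; `_example_stream` is a hypothetical
# helper name, and calling it performs a real network request.
def _example_stream() -> None:
    for piece in stream_chat_completion(
        model_id="deepseek-ai/DeepSeek-V3-0324",
        messages=[{"role": "user", "content": "Count from 1 to 5."}],
        max_tokens=64,
    ):
        print(piece, end="", flush=True)
    print()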