| from dataclasses import dataclass, field |
| from typing import List, Optional |
|
|
@dataclass
class ModelInfo:
    """Static metadata describing a model backend.

    The four leading fields are required; the multimodal-related fields
    default to "not multimodal" with no placeholder or data key.
    """

    # Size label the model was requested with.
    size: str
    # Human-readable model identifier.
    name: str
    # Maximum model context length (presumably in tokens — confirm).
    max_model_len: int
    # Whether the model is a chat model (presumably expects chat-formatted
    # prompts — confirm with callers).
    is_chat: bool
    # Multimodal support flag; off unless a backend sets it.
    is_multimodal: bool = field(default=False)
    # NOTE(review): presumably the marker inserted into prompts where an
    # image goes; only meaningful when is_multimodal is True — confirm.
    image_placeholder: Optional[str] = field(default=None)
    # NOTE(review): presumably the key under which multimodal data is
    # passed to the backend — confirm.
    mm_data_key: Optional[str] = field(default=None)
|
|
@dataclass
class InferenceResult:
    """Output produced for one prompt by a model's ``infer`` call."""

    # The prompt this result corresponds to.
    prompt: str
    # Generated completion text.
    text: str
    # Generated token ids; each instance gets its own fresh list.
    token_ids: List[int] = field(default_factory=list)
    # Optional log-probabilities (presumably one per generated token —
    # confirm); None when not requested or not produced.
    logprobs: Optional[List[float]] = field(default=None)
|
|
class BaseModel:
    """Interface shared by all model backends.

    Both methods here are stubs that raise ``NotImplementedError``;
    concrete backends override them.
    """

    def info(self) -> ModelInfo:
        """Return the backend's ``ModelInfo`` metadata.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def infer(
        self,
        prompts: List[str],
        *,
        max_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.95,
        logprobs: Optional[int] = None,
        stop: Optional[List[str]] = None,
    ) -> List[InferenceResult]:
        """Run generation over a batch of prompts.

        All options after ``prompts`` are keyword-only. Their exact meaning
        is up to each backend — NOTE(review): e.g. whether ``logprobs`` is a
        top-k count; document once a real backend defines it.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
|
|
| |
class _DummyModel(BaseModel):
    """Stand-in backend that performs no real inference.

    Every completion is a fixed template wrapped around the prompt, with
    constant token ids and no logprobs.
    """

    def __init__(self, size: str):
        # Everything except the size label (and the name derived from it)
        # is hard-coded.
        self._info = ModelInfo(
            name=f"dummy-{size}",
            size=size,
            is_chat=False,
            max_model_len=4096,
        )

    def info(self) -> ModelInfo:
        """Return the metadata object built in ``__init__``."""
        return self._info

    def infer(self, prompts: List[str], **kwargs) -> List[InferenceResult]:
        """Return one canned ``InferenceResult`` per prompt, in input order.

        Any keyword options in ``kwargs`` are accepted but ignored
        (presumably to mirror ``BaseModel.infer``'s options).
        """
        return [
            InferenceResult(
                prompt=prompt,
                text=f"[dummy completion for]: {prompt}",
                token_ids=[0, 1, 2],
                logprobs=None,
            )
            for prompt in prompts
        ]
|
|
def get_model(size: str) -> BaseModel:
    """Return a model backend for the requested size.

    Every supported size is currently served by ``_DummyModel``.

    Args:
        size: One of ``"small"``, ``"medium"`` or ``"large"``.

    Returns:
        A ``BaseModel`` instance for ``size``.

    Raises:
        ValueError: If ``size`` does not equal one of the supported values,
            including when it is not a string or is unhashable.
    """
    # Tuple membership compares with ``==`` and never hashes ``size``, so an
    # unhashable argument (e.g. a list) gets the documented ValueError instead
    # of the TypeError a set lookup would raise.
    if size not in ("small", "medium", "large"):
        raise ValueError(
            f"size must be one of {{'small','medium','large'}}, got {size!r}"
        )
    return _DummyModel(size)
|
|