File size: 860 Bytes
2de1d47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from __future__ import annotations

from dataclasses import dataclass

from openai import OpenAI


@dataclass
class ModelRunner:
    """Pair a model name with the chat-completions client used to query it.

    Instances are normally created through :meth:`load`; :meth:`generate`
    then sends a list of chat messages and returns the text of the first
    choice in the response.
    """

    # Model identifier sent as ``model`` on every completion request.
    model_name: str
    # Client object whose ``chat.completions.create`` performs the request.
    _client: OpenAI

    @classmethod
    def load(cls, model_name: str, api_key: str, base_url: str) -> "ModelRunner":
        """Create a runner backed by a new ``OpenAI`` client.

        Prints a one-line status message to stdout before constructing the
        client.

        Args:
            model_name: Model identifier stored on the runner.
            api_key: Passed to ``OpenAI`` as ``api_key``.
            base_url: Passed to ``OpenAI`` as ``base_url``.

        Returns:
            A ``ModelRunner`` holding ``model_name`` and the new client.
        """
        print(f"Connecting to {model_name} via API ...")
        client = OpenAI(base_url=base_url, api_key=api_key)
        return cls(model_name=model_name, _client=client)

    def generate(
        self,
        messages: list[dict[str, str]],
        temperature: float = 0.3,
        max_new_tokens: int = 512,
    ) -> str:
        """Request a chat completion and return the first choice's text.

        Args:
            messages: Chat messages forwarded unchanged as ``messages``.
            temperature: Forwarded as ``temperature``.
            max_new_tokens: Forwarded to the API as ``max_tokens``.

        Returns:
            The ``content`` of the first choice's message, or ``""`` when
            the response carries no text content (``content`` is ``None``).

        Raises:
            IndexError: If the response's ``choices`` sequence is empty.
        """
        response = self._client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            temperature=temperature,
            max_tokens=max_new_tokens,
        )
        content = response.choices[0].message.content
        # ``content`` is nullable in the API response; normalise it so the
        # declared ``str`` return type always holds for callers.
        return content if content is not None else ""