luisejdm commited on
Commit
2de1d47
·
1 Parent(s): e309431

Add model file

Browse files
Files changed (1) hide show
  1. model.py +31 -0
model.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ from openai import OpenAI
6
+
7
+
8
+ @dataclass
9
+ class ModelRunner:
10
+ model_name: str
11
+ _client: OpenAI
12
+
13
+ @classmethod
14
+ def load(cls, model_name: str, api_key: str, base_url: str) -> "ModelRunner":
15
+ print(f"Connecting to {model_name} via API ...")
16
+ client = OpenAI(base_url=base_url, api_key=api_key)
17
+ return cls(model_name=model_name, _client=client)
18
+
19
+ def generate(
20
+ self,
21
+ messages: list[dict[str, str]],
22
+ temperature: float = 0.3,
23
+ max_new_tokens: int = 512,
24
+ ) -> str:
25
+ response = self._client.chat.completions.create(
26
+ model=self.model_name,
27
+ messages=messages,
28
+ temperature=temperature,
29
+ max_tokens=max_new_tokens,
30
+ )
31
+ return response.choices[0].message.content