# File size: 3,221 Bytes
# 1905805
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from openai import AsyncOpenAI

from .base import T2TOperation
from utils.prompter.message import ChatMessage
from utils.prompter import Prompter

class OpenAIT2T(T2TOperation):
    '''Text-to-text operation that streams chat completions through AsyncOpenAI.

    The client is created in start() and released in close(). Sampling
    parameters and the endpoint are plain attributes that configure() updates
    and get_configuration() reports.
    '''

    def __init__(self):
        super().__init__("openai")
        # None until start() creates the client; reset to None by close().
        self.client = None

        self.base_url = "https://api.openai.com/v1/"
        self.model = "gpt-4o"
        self.temperature = 1
        self.top_p = 0.9
        self.presence_penalty = 0
        self.frequency_penalty = 0

    async def start(self):
        '''Run the base-class start, then create the client for base_url.'''
        await super().start()
        self.client = AsyncOpenAI(base_url=self.base_url)

    async def close(self):
        '''Run the base-class close, then close the client if one exists.'''
        await super().close()
        # close() can be reached before start() or more than once; detach the
        # client first so a missing or already-closed client is a no-op.
        client, self.client = self.client, None
        if client is not None:
            await client.close()

    async def configure(self, config_d):
        '''Configure and validate operation-specific configuration

        Recognised keys in config_d: "base_url", "model" (converted with
        str()) and "temperature", "top_p", "presence_penalty",
        "frequency_penalty" (converted with float()). Keys that are absent
        keep their current value.

        All values are validated before any attribute is written, so a
        rejected configuration leaves the operation unchanged.

        Raises:
            AssertionError: a resulting value is empty or out of range.
            ValueError / TypeError: a numeric value cannot be converted by
                float().
        '''
        pending = {
            "base_url": self.base_url,
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "presence_penalty": self.presence_penalty,
            "frequency_penalty": self.frequency_penalty,
        }
        for key in ("base_url", "model"):
            if key in config_d:
                pending[key] = str(config_d[key])
        for key in ("temperature", "top_p", "presence_penalty", "frequency_penalty"):
            if key in config_d:
                pending[key] = float(config_d[key])

        # NOTE(review): OpenAI documents -2.0..2.0 for both penalties; the
        # narrower 0..1 range below is the existing behaviour, kept as-is.
        checks = (
            (pending["base_url"] is not None and len(pending["base_url"]) > 0,
             "base_url must be a non-empty string"),
            (pending["model"] is not None and len(pending["model"]) > 0,
             "model must be a non-empty string"),
            (0 <= pending["temperature"] <= 2,
             "temperature must be between 0 and 2"),
            (0 <= pending["top_p"] <= 1,
             "top_p must be between 0 and 1"),
            (0 <= pending["presence_penalty"] <= 1,
             "presence_penalty must be between 0 and 1"),
            (0 <= pending["frequency_penalty"] <= 1,
             "frequency_penalty must be between 0 and 1"),
        )
        for ok, problem in checks:
            if not ok:
                # Raised explicitly (not via `assert`) so the check still runs
                # under `python -O`; the exception type is unchanged.
                raise AssertionError(problem)

        self.base_url = pending["base_url"]
        self.model = pending["model"]
        self.temperature = pending["temperature"]
        self.top_p = pending["top_p"]
        self.presence_penalty = pending["presence_penalty"]
        self.frequency_penalty = pending["frequency_penalty"]

    async def get_configuration(self):
        '''Returns values of configurable fields'''
        return {
            "base_url": self.base_url,
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "presence_penalty": self.presence_penalty,
            "frequency_penalty": self.frequency_penalty,
        }

    async def _generate(self, instruction_prompt: str = None, messages: list = None, **kwargs):
        '''Stream a chat completion for the given prompt and message history.

        Args:
            instruction_prompt: Sent as the system message. When None, no
                system message is sent.
            messages: Conversation entries. A ChatMessage whose user equals
                Prompter().character_name becomes an assistant turn with its
                message text; every other entry becomes a user turn with
                msg.to_line(). None is treated as an empty history.
            **kwargs: Accepted and ignored.

        Yields:
            dict: {"content": <text delta>} for each streamed chunk that has
            a choice; the delta is "" when the chunk carries no text.
        '''
        history = []
        if instruction_prompt is not None:
            history.append({"role": "system", "content": instruction_prompt})

        for msg in messages or ():
            if isinstance(msg, ChatMessage) and msg.user == Prompter().character_name:
                history.append({"role": "assistant", "content": msg.message})
            else:
                history.append({"role": "user", "content": msg.to_line()})

        stream = await self.client.chat.completions.create(
            messages=history,
            model=self.model,
            stream=True,
            temperature=self.temperature,
            top_p=self.top_p,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty
        )

        async for chunk in stream:
            # A streamed chunk may arrive with an empty `choices` list (the
            # API documents this for usage-only chunks); skip it rather than
            # fail on choices[0].
            if not chunk.choices:
                continue
            yield {"content": chunk.choices[0].delta.content or ""}