File size: 5,732 Bytes
77320e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import unittest
from datetime import datetime

from src.db.conversation_do import ConversationDO, SandboxStatus, ConversationStatus, MessageConf
from src.schemas import RoleType


class TestConversationDO(unittest.TestCase):
    """Unit tests for ConversationDO construction, (de)serialization and updates."""

    def setUp(self):
        """Build a fresh sample payload (including the newer fields) for every test.

        NOTE(review): ``datetime.utcnow()`` is deprecated since Python 3.12. It is
        kept (naive UTC) because whether ConversationDO expects naive or
        timezone-aware timestamps is not visible here -- TODO confirm before
        switching to ``datetime.now(timezone.utc)``.
        """
        self.sample_data = {
            "conversation_id": "12345",
            "messages": [{"role": "user", "content": "Hello"}],
            "input_files": [{"name": "file1.txt"}],
            "output_files": [{"name": "file2.txt"}],
            "sandbox_id": "sandbox_sample",
            "sandbox_status": "RUNNING",
            "create_time": datetime.utcnow(),
            "update_time": datetime.utcnow(),
            "user": "test_user_v3",
            "model_name": "sample_model",
            "model_conf_path": "path/to/conf",
            "llm_name": "sample_llm",
            "agent_type": "sample_agent",
            "request_id": "request_sample",
            "dead_sandbox_ids": ["dead1", "dead2"],
            "status": "RUNNING"
        }

    def _assert_message_conf(self, conversation, temperature, top_p, top_k):
        """Assert that ``conversation.message_conf`` carries the given sampling values.

        The values are read through ``conversation.to_dict()["message_conf"]``
        instead of ``message_conf.__temperature``. Inside this class body, an
        attribute reference spelled ``__name`` is name-mangled to
        ``_TestConversationDO__name``, so it can never reach MessageConf's
        private fields and would always raise AttributeError.
        """
        self.assertIsNotNone(conversation.message_conf)
        conf = conversation.to_dict()["message_conf"]
        self.assertEqual(conf["temperature"], temperature)
        self.assertEqual(conf["top_p"], top_p)
        self.assertEqual(conf["top_k"], top_k)

    def test_init(self):
        """Keyword construction parses raw message dicts into message objects."""
        conversation = ConversationDO(**self.sample_data)
        self.assertEqual(conversation.conversation_id, "12345")
        self.assertEqual(len(conversation.messages), 1)
        self.assertEqual(conversation.messages[0].role, RoleType.User)
        self.assertEqual(conversation.messages[0].content, "Hello")

    def test_to_dict(self):
        """to_dict emits status names as strings and message roles as ints."""
        conversation = ConversationDO(**self.sample_data)
        data = conversation.to_dict()
        self.assertEqual(data["conversation_id"], "12345")
        self.assertEqual(data["sandbox_status"], "RUNNING")
        self.assertEqual(len(data["messages"]), 1)
        self.assertEqual(data["messages"][0]["role"], 0)
        self.assertEqual(data["messages"][0]["content"], "Hello")

    def test_from_dict(self):
        """from_dict builds the same object as keyword construction."""
        data = self.sample_data.copy()
        conversation = ConversationDO.from_dict(data)
        self.assertEqual(conversation.conversation_id, "12345")
        self.assertEqual(len(conversation.messages), 1)
        self.assertEqual(conversation.messages[0].role, RoleType.User)
        self.assertEqual(conversation.messages[0].content, "Hello")

    def test_update(self):
        """update() applies status, scalar and message changes from a dict."""
        conversation = ConversationDO(**self.sample_data)
        updated_data = {
            "sandbox_status": "KILLED",
            "user": "updateduser",
            "messages": [{"role": "Agent", "content": "Hi"}]
        }
        conversation.update(updated_data)
        self.assertEqual(conversation.sandbox_status, SandboxStatus.KILLED)
        self.assertEqual(conversation.user, "updateduser")
        self.assertEqual(len(conversation.messages), 1)
        self.assertEqual(conversation.messages[0].role, RoleType.Agent)
        self.assertEqual(conversation.messages[0].content, "Hi")

    def test_invalid_status(self):
        """Unrecognised status strings fall back to the UNKNOWN enum members."""
        # Testing invalid sandbox_status
        data = self.sample_data.copy()
        data["sandbox_status"] = "INVALID_STATUS"
        conversation = ConversationDO(**data)
        self.assertEqual(conversation.sandbox_status, SandboxStatus.UNKNOWN)

        # Testing invalid conversation status
        data["status"] = "INVALID_STATUS"
        conversation = ConversationDO.from_dict(data)
        self.assertEqual(conversation.status, ConversationStatus.UNKNOWN)

    def test_is_in_running_status(self):
        """is_in_running_status() tracks the conversation status field."""
        conversation = ConversationDO(**self.sample_data)
        self.assertTrue(conversation.is_in_running_status())
        conversation.status = ConversationStatus.COMPLETED
        self.assertFalse(conversation.is_in_running_status())

    def test_message_conf(self):
        """message_conf may be given as a MessageConf instance or a plain dict."""
        # Test initialization with a MessageConf object
        conf = MessageConf(temperature=0.9, top_p=0.8, top_k=5)
        conversation = ConversationDO(message_conf=conf, **self.sample_data)
        self._assert_message_conf(conversation, 0.9, 0.8, 5)

        # Test initialization with a dictionary
        conf_dict = {"temperature": 0.9, "top_p": 0.8, "top_k": 5}
        conversation = ConversationDO(message_conf=conf_dict, **self.sample_data)
        self._assert_message_conf(conversation, 0.9, 0.8, 5)

    def test_to_dict_with_message_conf(self):
        """to_dict serializes message_conf into a nested mapping."""
        conf = MessageConf(temperature=0.9, top_p=0.8, top_k=5)
        conversation = ConversationDO(message_conf=conf, **self.sample_data)
        self._assert_message_conf(conversation, 0.9, 0.8, 5)

    def test_from_dict_with_message_conf(self):
        """from_dict accepts message_conf as a nested dict."""
        data = self.sample_data.copy()
        data["message_conf"] = {"temperature": 0.9, "top_p": 0.8, "top_k": 5}
        conversation = ConversationDO.from_dict(data)
        self._assert_message_conf(conversation, 0.9, 0.8, 5)

    def test_update_with_message_conf(self):
        """update() sets message_conf from a nested dict."""
        conversation = ConversationDO(**self.sample_data)
        updated_data = {
            "message_conf": {"temperature": 0.9, "top_p": 0.8, "top_k": 5}
        }
        conversation.update(updated_data)
        self._assert_message_conf(conversation, 0.9, 0.8, 5)