|
|
import unittest |
|
|
from datetime import datetime |
|
|
|
|
|
from src.db.conversation_do import ConversationDO, SandboxStatus, ConversationStatus, MessageConf |
|
|
from src.schemas import RoleType |
|
|
|
|
|
|
|
|
class TestConversationDO(unittest.TestCase):
    """Unit tests for construction, (de)serialisation and updates of ConversationDO."""

    def setUp(self):
        """Build a fresh, fully-populated keyword dict for ConversationDO per test."""
        # One timestamp for both fields so create_time and update_time agree.
        now = datetime.utcnow()
        self.sample_data = {
            "conversation_id": "12345",
            "messages": [{"role": "user", "content": "Hello"}],
            "input_files": [{"name": "file1.txt"}],
            "output_files": [{"name": "file2.txt"}],
            "sandbox_id": "sandbox_sample",
            "sandbox_status": "RUNNING",
            "create_time": now,
            "update_time": now,
            "user": "test_user_v3",
            "model_name": "sample_model",
            "model_conf_path": "path/to/conf",
            "llm_name": "sample_llm",
            "agent_type": "sample_agent",
            "request_id": "request_sample",
            "dead_sandbox_ids": ["dead1", "dead2"],
            "status": "RUNNING",
        }

    @staticmethod
    def _conf_attr(message_conf, name):
        """Return the private attribute ``__<name>`` stored on a MessageConf.

        Writing ``message_conf.__temperature`` inside this class would be
        name-mangled to ``message_conf._TestConversationDO__temperature`` and
        raise AttributeError. The lookup is therefore done with the name
        mangled for MessageConf instead.
        NOTE(review): assumes MessageConf keeps these values as private
        ``self.__<name>`` attributes, as the original assertions imply — confirm.
        """
        return getattr(message_conf, f"_{MessageConf.__name__}__{name}")

    def _assert_message_conf(self, message_conf):
        """Assert *message_conf* carries temperature=0.9, top_p=0.8, top_k=5."""
        self.assertEqual(self._conf_attr(message_conf, "temperature"), 0.9)
        self.assertEqual(self._conf_attr(message_conf, "top_p"), 0.8)
        self.assertEqual(self._conf_attr(message_conf, "top_k"), 5)

    def test_init(self):
        """Keyword construction parses message dicts into message objects."""
        conversation = ConversationDO(**self.sample_data)
        self.assertEqual(conversation.conversation_id, "12345")
        self.assertEqual(len(conversation.messages), 1)
        self.assertEqual(conversation.messages[0].role, RoleType.User)
        self.assertEqual(conversation.messages[0].content, "Hello")

    def test_to_dict(self):
        """to_dict serialises status as a string and the message role as 0."""
        conversation = ConversationDO(**self.sample_data)
        data = conversation.to_dict()
        self.assertEqual(data["conversation_id"], "12345")
        self.assertEqual(data["sandbox_status"], "RUNNING")
        self.assertEqual(len(data["messages"]), 1)
        self.assertEqual(data["messages"][0]["role"], 0)
        self.assertEqual(data["messages"][0]["content"], "Hello")

    def test_from_dict(self):
        """from_dict builds the same object as keyword construction."""
        data = self.sample_data.copy()
        conversation = ConversationDO.from_dict(data)
        self.assertEqual(conversation.conversation_id, "12345")
        self.assertEqual(len(conversation.messages), 1)
        self.assertEqual(conversation.messages[0].role, RoleType.User)
        self.assertEqual(conversation.messages[0].content, "Hello")

    def test_update(self):
        """update() overwrites status, user and messages from a partial dict."""
        conversation = ConversationDO(**self.sample_data)
        updated_data = {
            "sandbox_status": "KILLED",
            "user": "updateduser",
            "messages": [{"role": "Agent", "content": "Hi"}],
        }
        conversation.update(updated_data)
        self.assertEqual(conversation.sandbox_status, SandboxStatus.KILLED)
        self.assertEqual(conversation.user, "updateduser")
        self.assertEqual(len(conversation.messages), 1)
        self.assertEqual(conversation.messages[0].role, RoleType.Agent)
        self.assertEqual(conversation.messages[0].content, "Hi")

    def test_invalid_status(self):
        """Unrecognised status strings map to the UNKNOWN enum members."""
        data = self.sample_data.copy()

        # Invalid sandbox status via keyword construction.
        data["sandbox_status"] = "INVALID_STATUS"
        conversation = ConversationDO(**data)
        self.assertEqual(conversation.sandbox_status, SandboxStatus.UNKNOWN)

        # Invalid conversation status via from_dict.
        data["status"] = "INVALID_STATUS"
        conversation = ConversationDO.from_dict(data)
        self.assertEqual(conversation.status, ConversationStatus.UNKNOWN)

    def test_is_in_running_status(self):
        """is_in_running_status tracks the conversation status field."""
        conversation = ConversationDO(**self.sample_data)
        self.assertTrue(conversation.is_in_running_status())
        conversation.status = ConversationStatus.COMPLETED
        self.assertFalse(conversation.is_in_running_status())

    def test_message_conf(self):
        """message_conf accepts both a MessageConf instance and a plain dict."""
        conf = MessageConf(temperature=0.9, top_p=0.8, top_k=5)
        conversation = ConversationDO(message_conf=conf, **self.sample_data)
        self._assert_message_conf(conversation.message_conf)

        conf_dict = {"temperature": 0.9, "top_p": 0.8, "top_k": 5}
        conversation = ConversationDO(message_conf=conf_dict, **self.sample_data)
        self._assert_message_conf(conversation.message_conf)

    def test_to_dict_with_message_conf(self):
        """to_dict serialises message_conf as a dict of its public field names."""
        conf = MessageConf(temperature=0.9, top_p=0.8, top_k=5)
        conversation = ConversationDO(message_conf=conf, **self.sample_data)
        data = conversation.to_dict()
        self.assertEqual(data["message_conf"]["temperature"], 0.9)
        self.assertEqual(data["message_conf"]["top_p"], 0.8)
        self.assertEqual(data["message_conf"]["top_k"], 5)

    def test_from_dict_with_message_conf(self):
        """from_dict turns a message_conf dict into a populated MessageConf."""
        data = self.sample_data.copy()
        data["message_conf"] = {"temperature": 0.9, "top_p": 0.8, "top_k": 5}
        conversation = ConversationDO.from_dict(data)
        self._assert_message_conf(conversation.message_conf)

    def test_update_with_message_conf(self):
        """update() replaces message_conf from a dict."""
        conversation = ConversationDO(**self.sample_data)
        updated_data = {
            "message_conf": {"temperature": 0.9, "top_p": 0.8, "top_k": 5},
        }
        conversation.update(updated_data)
        self._assert_message_conf(conversation.message_conf)
|
|