uskybox commited on
Commit
706c037
·
verified ·
1 Parent(s): 7f049a3

Upload hunter_agent.py

Browse files
Files changed (1) hide show
  1. werewolf/hunter/hunter_agent.py +179 -21
werewolf/hunter/hunter_agent.py CHANGED
@@ -1,34 +1,192 @@
1
- import re
2
  from agent_build_sdk.model.roles import ROLE_HUNTER
3
- from agent_build_sdk.model.werewolf_model import AgentResp, AgentReq, STATUS_START, STATUS_SKILL, STATUS_DISCUSS, STATUS_VOTE, STATUS_HUNTER, STATUS_SHERIFF_ELECTION
 
 
 
 
4
  from agent_build_sdk.sdk.role_agent import BasicRoleAgent
5
  from agent_build_sdk.sdk.agent import format_prompt
6
- from hunter.prompt import GAME_RULE_PROMPT, DESC_PROMPT, SKILL_PROMPT, VOTE_PROMPT
 
 
 
7
 
8
  class HunterAgent(BasicRoleAgent):
 
 
9
  def __init__(self, model_name):
10
  super().__init__(ROLE_HUNTER, model_name=model_name)
 
11
 
12
- def perceive(self, req: AgentReq):
13
  if req.status == STATUS_START:
14
  self.memory.clear()
 
 
15
  self.memory.append_history(GAME_RULE_PROMPT)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- def interact(self, req: AgentReq) -> AgentResp:
18
- name = req.name if req.name else "我"
19
- if req.status == STATUS_SHERIFF_ELECTION:
20
- return AgentResp(success=True, result="否")
21
-
22
- hist = "\n".join(self.memory.load_history()[-10:])
23
- if req.status in [STATUS_SKILL, STATUS_HUNTER]:
24
- res = self.llm_caller(format_prompt(SKILL_PROMPT, {"history": hist, "name": name, "choices": req.message}))
25
- target = re.search(r"(\d+)", res).group(1) if re.search(r"(\d+)", res) else ""
26
- return AgentResp(success=True, result=res, skillTargetPlayer=target)
27
- elif req.status == STATUS_DISCUSS:
28
- res = self.llm_caller(format_prompt(DESC_PROMPT, {"history": hist, "name": name, "shoot_info": "可以开枪"}))
29
- return AgentResp(success=True, result=res[:200])
30
  elif req.status == STATUS_VOTE:
31
- res = self.llm_caller(format_prompt(VOTE_PROMPT, {"history": hist, "choices": req.message, "name": name}))
32
- target = re.search(r"(\d+)", res).group(1) if re.search(r"(\d+)", res) else req.message.split(",")[0]
33
- return AgentResp(success=True, result=target)
34
- return AgentResp(success=True, result="猎人带队。")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from agent_build_sdk.model.roles import ROLE_HUNTER
2
+ from agent_build_sdk.model.werewolf_model import AgentResp, AgentReq, STATUS_START, STATUS_WOLF_SPEECH, \
3
+ STATUS_VOTE_RESULT, STATUS_SKILL, STATUS_SKILL_RESULT, STATUS_NIGHT_INFO, STATUS_DAY, STATUS_DISCUSS, STATUS_VOTE, \
4
+ STATUS_RESULT, STATUS_NIGHT, STATUS_SHERIFF_ELECTION, STATUS_SHERIFF_SPEECH, STATUS_SHERIFF_VOTE, STATUS_SHERIFF, \
5
+ STATUS_SHERIFF_SPEECH_ORDER, STATUS_SHERIFF_PK
6
+ from agent_build_sdk.utils.logger import logger
7
  from agent_build_sdk.sdk.role_agent import BasicRoleAgent
8
  from agent_build_sdk.sdk.agent import format_prompt
9
+ from hunter.prompt import DESC_PROMPT, VOTE_PROMPT, SKILL_PROMPT, GAME_RULE_PROMPT, CLEAN_USER_PROMPT, \
10
+ SHERIFF_ELECTION_PROMPT, SHERIFF_SPEECH_PROMPT, SHERIFF_VOTE_PROMPT, SHERIFF_SPEECH_ORDER_PROMPT, \
11
+ SHERIFF_TRANSFER_PROMPT
12
+
13
 
14
  class HunterAgent(BasicRoleAgent):
15
+ """猎人角色Agent"""
16
+
17
  def __init__(self, model_name):
18
  super().__init__(ROLE_HUNTER, model_name=model_name)
19
+ self.memory.set_variable("can_shoot", True) # 猎人初始可以开枪
20
 
21
+ def perceive(self, req=AgentReq):
22
  if req.status == STATUS_START:
23
  self.memory.clear()
24
+ self.memory.set_variable("name", req.name)
25
+ self.memory.set_variable("can_shoot", True)
26
  self.memory.append_history(GAME_RULE_PROMPT)
27
+ self.memory.append_history(f"主持人:你好,你分配到的角色是[猎人],你是{req.name}")
28
+ elif req.status == STATUS_NIGHT:
29
+ self.memory.append_history("主持人:现在进入夜晚,天黑请闭眼")
30
+ elif req.status == STATUS_SKILL_RESULT:
31
+ self.memory.append_history(f"主持人:{req.message}")
32
+ # 根据技能结果更新开枪状态
33
+ if "能开枪" in req.message:
34
+ self.memory.set_variable("can_shoot", True)
35
+ elif "不能开枪" in req.message:
36
+ self.memory.set_variable("can_shoot", False)
37
+ elif req.status == STATUS_NIGHT_INFO:
38
+ self.memory.append_history(f"主持人:天亮了!昨天晚上的信息是: {req.message}")
39
+ elif req.status == STATUS_DISCUSS: # 发言环节
40
+ if req.name:
41
+ # 其他玩家发言
42
+ # 可以使用模型来过滤掉玩家的注入消息,也可以换一个小模型,实际使用需要考虑对memory加锁,避免interact的时候丢失消息
43
+ # clean_user_message_prompt = format_prompt(CLEAN_USER_PROMPT, {"user_message": req.message})
44
+ # req.message = self.llm_caller(clean_user_message_prompt)
45
+ self.memory.append_history(req.name + ': ' + req.message)
46
+ else:
47
+ # 主持人发言
48
+ self.memory.append_history('主持人: 现在进入第{}天。'.format(str(req.round)))
49
+ self.memory.append_history('主持人: 每个玩家描述自己的信息。')
50
+ self.memory.append_history("---------------------------------------------")
51
+ elif req.status == STATUS_VOTE: # 投票环节
52
+ self.memory.append_history(f'第{req.round}天。投票信息:{req.name}投了{req.message}')
53
+ elif req.status == STATUS_VOTE_RESULT: # 投票环节
54
+ if req.name:
55
+ self.memory.append_history('主持人: 投票结果是:{}。'.format(req.name))
56
+ else:
57
+ self.memory.append_history('主持人: 无人出局。')
58
+ elif req.status == STATUS_SHERIFF_ELECTION:
59
+ self.memory.append_history(f"主持人: 上警玩家: {req.message}")
60
+ elif req.status == STATUS_SHERIFF_SPEECH:
61
+ self.memory.append_history(f"{req.name} (警上发言): {req.message}")
62
+ elif req.status == STATUS_SHERIFF_VOTE:
63
+ self.memory.append_history(f"警上投票: {req.name}投了{req.message}")
64
+ elif req.status == STATUS_SHERIFF:
65
+ if req.name:
66
+ self.memory.append_history(f"主持人: 警徽归属: {req.name}")
67
+ self.memory.set_variable("sheriff", req.name)
68
+ if req.message:
69
+ self.memory.append_history(req.message)
70
+ elif req.status == STATUS_RESULT:
71
+ self.memory.append_history(req.message)
72
+ elif req.status == STATUS_SHERIFF_SPEECH_ORDER:
73
+ if "小号" in req.message:
74
+ self.memory.append_history("主持人: 警长发言顺序是小号优先")
75
+ else:
76
+ self.memory.append_history("主持人: 警长发言顺序是大号优先")
77
+ elif req.status == STATUS_SHERIFF_PK:
78
+ self.memory.append_history(f"警长PK发言: {req.name}: {req.message}")
79
+ else:
80
+ raise NotImplementedError
81
+
82
+ def interact(self, req=AgentReq) -> AgentResp:
83
+ logger.info("hunter interact: {}".format(req))
84
+ if req.status == STATUS_DISCUSS:
85
+ if req.message:
86
+ self.memory.append_history(req.message)
87
+ can_shoot = self.memory.load_variable("can_shoot")
88
+ shoot_info = "可以开枪" if can_shoot else "不能开枪"
89
+ prompt = format_prompt(DESC_PROMPT,
90
+ {"name": self.memory.load_variable("name"),
91
+ "shoot_info": shoot_info,
92
+ "history": "\n".join(self.memory.load_history())
93
+ })
94
+ logger.info("prompt:" + prompt)
95
+ result = self.llm_caller(prompt)
96
+ logger.info("hunter interact result: {}".format(result))
97
+ return AgentResp(success=True, result=result, errMsg=None)
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  elif req.status == STATUS_VOTE:
100
+ self.memory.append_history('主持人: 到了投票的时候了。每个人,请指向你认为可能是狼人的人。')
101
+ choices = [name for name in req.message.split(",") if name != self.memory.load_variable("name")]
102
+ self.memory.set_variable("choices", choices)
103
+ prompt = format_prompt(VOTE_PROMPT, {"name": self.memory.load_variable("name"),
104
+ "choices": choices,
105
+ "history": "\n".join(self.memory.load_history())
106
+ })
107
+ logger.info("prompt:" + prompt)
108
+ result = self.llm_caller(prompt)
109
+ logger.info("hunter interact result: {}".format(result))
110
+ return AgentResp(success=True, result=result, errMsg=None)
111
+
112
+ elif req.status == STATUS_SKILL:
113
+ # 猎人技能:开枪射杀一名玩家(遗言阶段)
114
+ can_shoot = self.memory.load_variable("can_shoot")
115
+ if not can_shoot:
116
+ return AgentResp(success=True, result="不开枪", errMsg=None)
117
+
118
+ choices = [name for name in req.message.split(",") if name != self.memory.load_variable("name")]
119
+ prompt = format_prompt(SKILL_PROMPT, {
120
+ "name": self.memory.load_variable("name"),
121
+ "choices": choices,
122
+ "history": "\n".join(self.memory.load_history())
123
+ })
124
+ logger.info("prompt:" + prompt)
125
+ result = self.llm_caller(prompt)
126
+ logger.info("hunter skill result: {}".format(result))
127
+
128
+ if result != "不开枪":
129
+ self.memory.set_variable("can_shoot", False)
130
+
131
+ return AgentResp(success=True, result=result, skillTargetPlayer=None if result == "不开枪" else result, errMsg=None)
132
+
133
+ elif req.status == STATUS_SHERIFF_ELECTION:
134
+ can_shoot = self.memory.load_variable("can_shoot")
135
+ shoot_info = "可以开枪" if can_shoot else "不能开枪"
136
+ prompt = format_prompt(SHERIFF_ELECTION_PROMPT,
137
+ {"name": self.memory.load_variable("name"),
138
+ "shoot_info": shoot_info,
139
+ "history": "\n".join(self.memory.load_history())
140
+ })
141
+ logger.info("prompt:" + prompt)
142
+ result = self.llm_caller(prompt)
143
+ return AgentResp(success=True, result=result, errMsg=None)
144
+
145
+ elif req.status == STATUS_SHERIFF_SPEECH or req.status == STATUS_SHERIFF_PK:
146
+ can_shoot = self.memory.load_variable("can_shoot")
147
+ shoot_info = "可以开枪" if can_shoot else "不能开枪"
148
+ prompt = format_prompt(SHERIFF_SPEECH_PROMPT,
149
+ {"name": self.memory.load_variable("name"),
150
+ "shoot_info": shoot_info,
151
+ "history": "\n".join(self.memory.load_history())
152
+ })
153
+ logger.info("prompt:" + prompt)
154
+ result = self.llm_caller(prompt)
155
+ return AgentResp(success=True, result=result, errMsg=None)
156
+
157
+ elif req.status == STATUS_SHERIFF_VOTE:
158
+ choices = req.message.split(",")
159
+ prompt = format_prompt(SHERIFF_VOTE_PROMPT,
160
+ {"name": self.memory.load_variable("name"),
161
+ "choices": choices,
162
+ "history": "\n".join(self.memory.load_history())
163
+ })
164
+ logger.info("prompt:" + prompt)
165
+ result = self.llm_caller(prompt)
166
+ return AgentResp(success=True, result=result, errMsg=None)
167
+
168
+ elif req.status == STATUS_SHERIFF_SPEECH_ORDER:
169
+ prompt = format_prompt(SHERIFF_SPEECH_ORDER_PROMPT,
170
+ {"name": self.memory.load_variable("name"),
171
+ "history": "\n".join(self.memory.load_history())
172
+ })
173
+ logger.info("prompt:" + prompt)
174
+ result = self.llm_caller(prompt)
175
+ return AgentResp(success=True, result=result, errMsg=None)
176
+
177
+ elif req.status == STATUS_SHERIFF:
178
+ # 警长转移警徽
179
+ can_shoot = self.memory.load_variable("can_shoot")
180
+ shoot_info = "可以开枪" if can_shoot else "不能开枪"
181
+ choices = [name for name in req.message.split(",") if name != self.memory.load_variable("name")]
182
+ prompt = format_prompt(SHERIFF_TRANSFER_PROMPT,
183
+ {"name": self.memory.load_variable("name"),
184
+ "shoot_info": shoot_info,
185
+ "choices": choices,
186
+ "history": "\n".join(self.memory.load_history())
187
+ })
188
+ logger.info("prompt:" + prompt)
189
+ result = self.llm_caller(prompt)
190
+ return AgentResp(success=True, result=result, errMsg=None)
191
+ else:
192
+ raise NotImplementedError