Spaces:
Sleeping
Sleeping
Commit
·
97571ce
1
Parent(s):
16a6025
feat: rename ntu img to ruby
Browse files- utils/completion_reward.py +38 -7
utils/completion_reward.py
CHANGED
|
@@ -55,6 +55,11 @@ class CompletionReward:
|
|
| 55 |
self.google_agent = GoogleAgent()
|
| 56 |
self.mtk_agent = MTKAgent()
|
| 57 |
self.ntu_agent = NTUAgent()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
self.agents_responses = {}
|
| 59 |
self.agent_list = [
|
| 60 |
self.openai_agent,
|
|
@@ -154,6 +159,11 @@ class CompletionReward:
|
|
| 154 |
"paragraph_google": self.paragraph_google,
|
| 155 |
"paragraph_mtk": self.paragraph_mtk,
|
| 156 |
"paragraph_ntu": self.paragraph_ntu,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
"player_certificate_url": self.player_certificate_url,
|
| 158 |
"created_at_date": datetime.now().date(),
|
| 159 |
}
|
|
@@ -250,6 +260,7 @@ class OpenAIAgent:
|
|
| 250 |
|
| 251 |
retry_attempts = 0
|
| 252 |
while retry_attempts < 5:
|
|
|
|
| 253 |
try:
|
| 254 |
response = client.chat.completions.create(
|
| 255 |
model="gpt-4-1106-preview",
|
|
@@ -260,6 +271,7 @@ class OpenAIAgent:
|
|
| 260 |
presence_penalty=self.presence_penalty,
|
| 261 |
)
|
| 262 |
chinese_converter = OpenCC("s2tw")
|
|
|
|
| 263 |
return chinese_converter.convert(response.choices[0].message.content)
|
| 264 |
|
| 265 |
except Exception as e:
|
|
@@ -267,6 +279,7 @@ class OpenAIAgent:
|
|
| 267 |
logging.error(f"OpenAI Attempt {retry_attempts}: {e}")
|
| 268 |
time.sleep(1 * retry_attempts)
|
| 269 |
|
|
|
|
| 270 |
return '星際夥伴短時間內寫了太多故事,需要休息一下,請稍後再試,或是選擇其他星際夥伴的故事。'
|
| 271 |
|
| 272 |
def get_background(self):
|
|
@@ -321,12 +334,15 @@ class AWSAgent:
|
|
| 321 |
retry_attempts = 0
|
| 322 |
while retry_attempts < 5:
|
| 323 |
try:
|
|
|
|
| 324 |
completion = client.completions.create(
|
| 325 |
model="anthropic.claude-v2",
|
| 326 |
max_tokens_to_sample=2048,
|
| 327 |
prompt=f"{anthropic_bedrock.HUMAN_PROMPT}{system_prompt},以下是我的故事紀錄```{user_log}``` {anthropic_bedrock.AI_PROMPT}",
|
| 328 |
)
|
| 329 |
chinese_converter = OpenCC("s2tw")
|
|
|
|
|
|
|
| 330 |
return chinese_converter.convert(completion.completion)
|
| 331 |
|
| 332 |
except Exception as e:
|
|
@@ -334,6 +350,7 @@ class AWSAgent:
|
|
| 334 |
logging.error(f"AWS Attempt {retry_attempts}: {e}")
|
| 335 |
time.sleep(1 * retry_attempts)
|
| 336 |
|
|
|
|
| 337 |
return '星際夥伴短時間內寫了太多故事,需要休息一下,請稍後再試,或是選擇其他星際夥伴的故事。'
|
| 338 |
|
| 339 |
|
|
@@ -377,13 +394,14 @@ class GoogleAgent:
|
|
| 377 |
retry_attempts = 0
|
| 378 |
while retry_attempts < 5:
|
| 379 |
try:
|
|
|
|
| 380 |
logging.info("Google Generating response...")
|
| 381 |
model_response = self.gemini_pro_model.generate_content(
|
| 382 |
f"{system_prompt}, 以下是我的冒險故事 ```{user_log}```"
|
| 383 |
)
|
| 384 |
|
| 385 |
chinese_converter = OpenCC("s2tw")
|
| 386 |
-
|
| 387 |
return chinese_converter.convert(
|
| 388 |
model_response.candidates[0].content.parts[0].text
|
| 389 |
)
|
|
@@ -393,6 +411,7 @@ class GoogleAgent:
|
|
| 393 |
logging.error(f"Google Attempt {retry_attempts}: {e}")
|
| 394 |
time.sleep(1 * retry_attempts)
|
| 395 |
|
|
|
|
| 396 |
return '星際夥伴短時間內寫了太多故事,需要休息一下,請稍後再試,或是選擇其他星際夥伴的故事。'
|
| 397 |
|
| 398 |
|
|
@@ -447,6 +466,7 @@ class MTKAgent:
|
|
| 447 |
retry_attempts = 0
|
| 448 |
while retry_attempts < 5:
|
| 449 |
try:
|
|
|
|
| 450 |
response = requests.post(
|
| 451 |
url, headers=headers, data=json.dumps(data)
|
| 452 |
).json()
|
|
@@ -458,7 +478,7 @@ class MTKAgent:
|
|
| 458 |
extracted_content = "\n".join(matched_contents).strip()
|
| 459 |
|
| 460 |
chinese_converter = OpenCC("s2tw")
|
| 461 |
-
|
| 462 |
if extracted_content:
|
| 463 |
return chinese_converter.convert(extracted_content)
|
| 464 |
else:
|
|
@@ -468,7 +488,8 @@ class MTKAgent:
|
|
| 468 |
retry_attempts += 1
|
| 469 |
logging.error(f"MTK Attempt {retry_attempts}: {e}")
|
| 470 |
time.sleep(1 * retry_attempts)
|
| 471 |
-
|
|
|
|
| 472 |
return '星際夥伴短時間內寫了太多故事,需要休息一下,請稍後再試,或是選擇其他星際夥伴的故事。'
|
| 473 |
|
| 474 |
class NTUAgent:
|
|
@@ -489,14 +510,21 @@ class NTUAgent:
|
|
| 489 |
```{user_log}
|
| 490 |
```
|
| 491 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 492 |
|
| 493 |
url = 'http://api.twllm.com:20002/v1/chat/completions'
|
| 494 |
-
message = f"{system_prompt}, 以下是我的冒險故事 ```{user_log}```"
|
| 495 |
-
logging.warning(f"NTU Generating response...")
|
| 496 |
-
logging.warning(f"NTU message: {message}")
|
| 497 |
data = {
|
| 498 |
"model": "yentinglin/Taiwan-LLM-13B-v2.0-chat",
|
| 499 |
-
"messages":
|
| 500 |
"temperature": 0.7,
|
| 501 |
"top_p": 1,
|
| 502 |
"n": 1,
|
|
@@ -527,11 +555,13 @@ class NTUAgent:
|
|
| 527 |
retry_attempts = 0
|
| 528 |
while retry_attempts < 5:
|
| 529 |
try:
|
|
|
|
| 530 |
response = requests.post(url, headers=headers, data=json.dumps(data)).json()
|
| 531 |
response_text = response["choices"][0]["message"]["content"]
|
| 532 |
|
| 533 |
chinese_converter = OpenCC("s2tw")
|
| 534 |
|
|
|
|
| 535 |
return chinese_converter.convert(response_text)
|
| 536 |
|
| 537 |
except Exception as e:
|
|
@@ -539,6 +569,7 @@ class NTUAgent:
|
|
| 539 |
logging.error(f"NTU Attempt {retry_attempts}: {e}")
|
| 540 |
time.sleep(1 * retry_attempts)
|
| 541 |
|
|
|
|
| 542 |
return '星際夥伴短時間內寫了太多故事,需要休息一下,請稍後再試,或是選擇其他星際夥伴的故事。'
|
| 543 |
|
| 544 |
class ImageProcessor:
|
|
|
|
| 55 |
self.google_agent = GoogleAgent()
|
| 56 |
self.mtk_agent = MTKAgent()
|
| 57 |
self.ntu_agent = NTUAgent()
|
| 58 |
+
self.openai_response_time = None
|
| 59 |
+
self.aws_response_time = None
|
| 60 |
+
self.google_response_time = None
|
| 61 |
+
self.mtk_response_time = None
|
| 62 |
+
self.ntu_response_time = None
|
| 63 |
self.agents_responses = {}
|
| 64 |
self.agent_list = [
|
| 65 |
self.openai_agent,
|
|
|
|
| 159 |
"paragraph_google": self.paragraph_google,
|
| 160 |
"paragraph_mtk": self.paragraph_mtk,
|
| 161 |
"paragraph_ntu": self.paragraph_ntu,
|
| 162 |
+
"openai_response_time": self.openai_response_time,
|
| 163 |
+
"aws_response_time": self.aws_response_time,
|
| 164 |
+
"google_response_time": self.google_response_time,
|
| 165 |
+
"mtk_response_time": self.mtk_response_time,
|
| 166 |
+
"ntu_response_time": self.ntu_response_time,
|
| 167 |
"player_certificate_url": self.player_certificate_url,
|
| 168 |
"created_at_date": datetime.now().date(),
|
| 169 |
}
|
|
|
|
| 260 |
|
| 261 |
retry_attempts = 0
|
| 262 |
while retry_attempts < 5:
|
| 263 |
+
start_time = time.time()
|
| 264 |
try:
|
| 265 |
response = client.chat.completions.create(
|
| 266 |
model="gpt-4-1106-preview",
|
|
|
|
| 271 |
presence_penalty=self.presence_penalty,
|
| 272 |
)
|
| 273 |
chinese_converter = OpenCC("s2tw")
|
| 274 |
+
self.openai_response_time = time.time() - start_time
|
| 275 |
return chinese_converter.convert(response.choices[0].message.content)
|
| 276 |
|
| 277 |
except Exception as e:
|
|
|
|
| 279 |
logging.error(f"OpenAI Attempt {retry_attempts}: {e}")
|
| 280 |
time.sleep(1 * retry_attempts)
|
| 281 |
|
| 282 |
+
self.openai_response_time = time.time() - start_time
|
| 283 |
return '星際夥伴短時間內寫了太多故事,需要休息一下,請稍後再試,或是選擇其他星際夥伴的故事。'
|
| 284 |
|
| 285 |
def get_background(self):
|
|
|
|
| 334 |
retry_attempts = 0
|
| 335 |
while retry_attempts < 5:
|
| 336 |
try:
|
| 337 |
+
start_time = time.time()
|
| 338 |
completion = client.completions.create(
|
| 339 |
model="anthropic.claude-v2",
|
| 340 |
max_tokens_to_sample=2048,
|
| 341 |
prompt=f"{anthropic_bedrock.HUMAN_PROMPT}{system_prompt},以下是我的故事紀錄```{user_log}``` {anthropic_bedrock.AI_PROMPT}",
|
| 342 |
)
|
| 343 |
chinese_converter = OpenCC("s2tw")
|
| 344 |
+
|
| 345 |
+
self.aws_response_time = time.time() - start_time
|
| 346 |
return chinese_converter.convert(completion.completion)
|
| 347 |
|
| 348 |
except Exception as e:
|
|
|
|
| 350 |
logging.error(f"AWS Attempt {retry_attempts}: {e}")
|
| 351 |
time.sleep(1 * retry_attempts)
|
| 352 |
|
| 353 |
+
self.aws_response_time = time.time() - start_time
|
| 354 |
return '星際夥伴短時間內寫了太多故事,需要休息一下,請稍後再試,或是選擇其他星際夥伴的故事。'
|
| 355 |
|
| 356 |
|
|
|
|
| 394 |
retry_attempts = 0
|
| 395 |
while retry_attempts < 5:
|
| 396 |
try:
|
| 397 |
+
start_time = time.time()
|
| 398 |
logging.info("Google Generating response...")
|
| 399 |
model_response = self.gemini_pro_model.generate_content(
|
| 400 |
f"{system_prompt}, 以下是我的冒險故事 ```{user_log}```"
|
| 401 |
)
|
| 402 |
|
| 403 |
chinese_converter = OpenCC("s2tw")
|
| 404 |
+
self.google_response_time = time.time() - start_time
|
| 405 |
return chinese_converter.convert(
|
| 406 |
model_response.candidates[0].content.parts[0].text
|
| 407 |
)
|
|
|
|
| 411 |
logging.error(f"Google Attempt {retry_attempts}: {e}")
|
| 412 |
time.sleep(1 * retry_attempts)
|
| 413 |
|
| 414 |
+
self.google_response_time = time.time() - start_time
|
| 415 |
return '星際夥伴短時間內寫了太多故事,需要休息一下,請稍後再試,或是選擇其他星際夥伴的故事。'
|
| 416 |
|
| 417 |
|
|
|
|
| 466 |
retry_attempts = 0
|
| 467 |
while retry_attempts < 5:
|
| 468 |
try:
|
| 469 |
+
start_time = time.time()
|
| 470 |
response = requests.post(
|
| 471 |
url, headers=headers, data=json.dumps(data)
|
| 472 |
).json()
|
|
|
|
| 478 |
extracted_content = "\n".join(matched_contents).strip()
|
| 479 |
|
| 480 |
chinese_converter = OpenCC("s2tw")
|
| 481 |
+
self.mtk_response_time = time.time() - start_time
|
| 482 |
if extracted_content:
|
| 483 |
return chinese_converter.convert(extracted_content)
|
| 484 |
else:
|
|
|
|
| 488 |
retry_attempts += 1
|
| 489 |
logging.error(f"MTK Attempt {retry_attempts}: {e}")
|
| 490 |
time.sleep(1 * retry_attempts)
|
| 491 |
+
|
| 492 |
+
self.mtk_response_time = time.time() - start_time
|
| 493 |
return '星際夥伴短時間內寫了太多故事,需要休息一下,請稍後再試,或是選擇其他星際夥伴的故事。'
|
| 494 |
|
| 495 |
class NTUAgent:
|
|
|
|
| 510 |
```{user_log}
|
| 511 |
```
|
| 512 |
"""
|
| 513 |
+
messages = [
|
| 514 |
+
{
|
| 515 |
+
"role": "system",
|
| 516 |
+
"content": f"{system_prompt}",
|
| 517 |
+
},
|
| 518 |
+
{
|
| 519 |
+
"role": "user",
|
| 520 |
+
"content": f"{user_log}",
|
| 521 |
+
},
|
| 522 |
+
]
|
| 523 |
|
| 524 |
url = 'http://api.twllm.com:20002/v1/chat/completions'
|
|
|
|
|
|
|
|
|
|
| 525 |
data = {
|
| 526 |
"model": "yentinglin/Taiwan-LLM-13B-v2.0-chat",
|
| 527 |
+
"messages": messages,
|
| 528 |
"temperature": 0.7,
|
| 529 |
"top_p": 1,
|
| 530 |
"n": 1,
|
|
|
|
| 555 |
retry_attempts = 0
|
| 556 |
while retry_attempts < 5:
|
| 557 |
try:
|
| 558 |
+
start_time = time.time()
|
| 559 |
response = requests.post(url, headers=headers, data=json.dumps(data)).json()
|
| 560 |
response_text = response["choices"][0]["message"]["content"]
|
| 561 |
|
| 562 |
chinese_converter = OpenCC("s2tw")
|
| 563 |
|
| 564 |
+
self.ntu_response_time = time.time() - start_time
|
| 565 |
return chinese_converter.convert(response_text)
|
| 566 |
|
| 567 |
except Exception as e:
|
|
|
|
| 569 |
logging.error(f"NTU Attempt {retry_attempts}: {e}")
|
| 570 |
time.sleep(1 * retry_attempts)
|
| 571 |
|
| 572 |
+
self.ntu_response_time = time.time() - start_time
|
| 573 |
return '星際夥伴短時間內寫了太多故事,需要休息一下,請稍後再試,或是選擇其他星際夥伴的故事。'
|
| 574 |
|
| 575 |
class ImageProcessor:
|