| import json | |
| import os | |
| import random | |
| import logging | |
| import argparse | |
| import gradio as gr | |
| import datetime | |
| import time | |
| from rich.console import Console | |
| from rich.logging import RichHandler | |
| from rich.panel import Panel | |
| from tqdm import trange | |
| from LLM.offlinellm import OfflineLLM | |
| from LLM.apillm import APILLM | |
| from agent import Agent | |
| from frontEnd import frontEnd | |
| from main import CourtSimulation, parse_arguments | |
# Module-level rich Console instance. It is not referenced in the visible part
# of this script; presumably kept for importers or interactive use — verify
# before removing.
console = Console()
# Batch data generation entry point.
def main(simu_list=None, model="deepseek-v3-250324", *, keep_going=False):
    """Batch-generate court-simulation data.

    Builds one ``CourtSimulation`` from the command-line arguments and runs
    ``start_simluation`` once per case id, passing the same model name in
    all five model slots.

    Models used for batch runs so far (from the original notes):

    - local open-source: qwen2.5-7b-instruct, qwen2.5-32b-instruct (used as
      the base model), meta-llama/Llama-3.1-8B-Instruct, THUDM/glm-4-9b-chat
    - API-hosted: deepseek-r1, claude-3-sonnet, gpt-4o-mini

    Args:
        simu_list: Case ids to simulate. ``None`` (the default) means ``[0]``.
        model: Model name passed for every role. Defaults to
            ``"deepseek-v3-250324"``.
        keep_going: If True, an exception raised for one case is appended to
            ``failed_court.txt`` as ``"<id> <message>"`` and the batch moves on
            to the next case. If False (the default), the first exception
            propagates unchanged.
    """
    print("In main..")
    args = parse_arguments()
    simulation = CourtSimulation(
        args.init_config,
        args.stage_prompt,
        args.case,
        args.log_level,
        args.log_think,
        launch=False,
    )

    if simu_list is None:
        simu_list = [0]
    print(f"simu_list: {simu_list}")

    # Earlier experiments instead evaluated one test model per role, with
    # qwen2.5-32b-instruct in every other slot. In those runs the test model
    # went in the 1st model argument for the judge, the 3rd for the
    # prosecutor and the 5th for the defender; the 2nd and 4th slots were not
    # labeled — check CourtSimulation.start_simluation before reusing.
    for simu_id in simu_list:
        print(f"----- simulation id = {simu_id} -----")
        try:
            simulation.start_simluation(simu_id, model, model, model, model, model)
        except Exception as e:
            if not keep_going:
                raise
            # Best-effort batch mode: record the failure and continue.
            with open("failed_court.txt", "a", encoding="utf-8") as file:
                file.write(str(simu_id) + " " + str(e) + "\n")
            continue
        print(f"------- end {simu_id} --------")
# Run the batch generation only when executed as a script, not on import.
if __name__ == "__main__":
    main()