|
|
|
|
|
import pandas as pd
|
|
|
import json
|
|
|
|
|
|
|
|
|
# Location of the training set: a JSON file that is loaded into a DataFrame
# and must provide the 'text' and 'triple_list' columns read below.
TRAIN_TRIPLES_PATH = 'F:/GeoLLM/data/train_triples.json'

# Parse the whole file while it is open; the handle is closed on leaving the
# `with` block.
with open(TRAIN_TRIPLES_PATH, 'r', encoding='utf-8') as f:
    data = json.load(f)

df = pd.DataFrame(data)

# Input passages and their annotated triples, kept as two Series that share
# the DataFrame's index.
text = df['text']
label = df['triple_list']

# Report how many samples were loaded.
print(len(text))
|
|
|
from response_to_json import parse_llm_response, save_to_json, save_raw_response
|
|
|
from LLM import zero_shot
|
|
|
from prompt_generate import generate_prompt_with_examples as generate_prompt
|
|
|
from prompt_generate import generate_prompt_with_best_matches as generate_prompt_b
|
|
|
|
|
|
# Model selection, passed to zero_shot() for every request.
# NOTE(review): model_series is 'gpt' while model_name is a DeepSeek model;
# presumably model_series only selects the client/API flavour - confirm
# against LLM.zero_shot.
model_series = 'gpt'

# model_name is also used to build the output file names further down.
model_name = 'deepseek-ai/DeepSeek-R1'
|
|
|
|
|
|
# Instruction header (in Chinese) that is prepended to every request. It tells
# the model to extract "entity-relation-entity" triples, lists the 24 allowed
# relation types, and requires the answer to be a bare JSON array of
# {"entity1", "relation", "entity2"} objects.
# The literal starts and ends with a newline. Stray lines holding a lone '|'
# that had leaked into the literal were removed so they are not sent to the
# model; every real line of the prompt is unchanged.
prompt = '''
你是一名专业经验丰富的工程地质领域专家,你的任务是从给定的输入文本中提取"实体-关系-实体"三元组。关系类型包括24种:"出露于"、"位于"、"整合接触"、"不整合接触"、"假整合接触"、"断层接触"、"分布形态"、"大地构造位置"、"地层区划"、"出露地层"、"岩性"、"厚度"、"面积"、"坐标"、"长度"、"含有"、"所属年代"、"行政区划"、"发育"、"古生物"、"海拔"、"属于"、"吞噬"、"侵入"。提取过程请按照以下规范:
1. 输出格式:
严格遵循JSON数组,无额外文本,每个元素包含:
[
{
"entity1": "实体1",
"relation": "关系",
"entity2": "实体2"
}
]
2. 复杂关系处理:
- 若同一实体参与多个关系,需分别列出不同三元组
'''
|
|
|
def _count_saved(path):
    """Return the number of top-level entries in the JSON file at *path*.

    A file that does not exist yet counts as zero entries, so a first run
    starts from sample 0 instead of failing. Any other error (unreadable
    file, malformed JSON) still propagates.
    """
    try:
        # `with` guarantees the handle is closed once the file is parsed.
        with open(path, 'r', encoding='utf-8') as fh:
            return len(json.load(fh))
    except FileNotFoundError:
        return 0


# Output files for the parsed triples and for the raw model responses.
json_path = 'F:/GeoLLM/output/two_shot/' + model_name + '.json'
raw_json_path = 'F:/GeoLLM/output/two_shot_raw/' + model_name + '.json'

# j / q: how many entries each output file already holds. The loop below
# resumes at index j.
j = _count_saved(json_path)
q = _count_saved(raw_json_path)
print(j)

# Upper bound of the run: at most 500 samples, and never past the end of the
# loaded data set.
max_samples = min(500, len(text))

if q == j:
    # Both outputs are in step, so processing can resume at sample j.
    for i in range(j, max_samples):
        # Few-shot examples are regenerated on every iteration.
        # NOTE(review): the literal 2 is presumably the number of examples
        # (the output directory is named "two_shot") - confirm in
        # prompt_generate.
        prompt_string = generate_prompt(text, label, 2)

        # Request layout: instruction header, examples banner, the examples,
        # the task banner, then the passage to analyse - one per line.
        full_prompt = '\n'.join((
            prompt,
            '以下是地质描述文本和三元组提取样例',
            prompt_string,
            '请根据样例提取三元组',
            text[i],
        ))
        response = zero_shot(model_series, model_name, full_prompt)
        print(i)

        formatted_triples = parse_llm_response(response)

        # Persist the parsed triples and the raw response for this sample.
        # NOTE(review): the resume logic assumes each call adds exactly one
        # entry to its file - confirm in response_to_json.
        save_to_json(text[i], formatted_triples, model_series=model_name, output_dir='F:/GeoLLM/output/two_shot/')
        save_raw_response(response, text[i], model_series=model_name, output_dir='F:/GeoLLM/output/two_shot_raw/')
else:
    # The two output files hold different numbers of entries; do not resume.
    print('q!=j')