GeoLLM / prompt_generate.py
Pengfa Li
Upload folder using huggingface_hub
badcf3c verified
import random
import json
from difflib import SequenceMatcher
import pandas as pd
def generate_prompt_with_examples(text, label, n, start_index=500, end_index=1000):
"""
生成带有n个样例的提示。
Args:
text: 包含文本描述的 Series。
label: 包含三元组列表的 Series。
n: 要提取的样例数量。
start_index: 样例的起始索引。
end_index: 样例的结束索引。
Returns:
一个字符串,包含n个样例的提示,格式为text_prompt[i]+'\\n'+triple_prompt[i]。
如果n大于可用样例数量,则返回所有可用样例。
"""
text_len = len(text)
label_len = len(label)
end_index = min(end_index, text_len, label_len)
if start_index >= end_index:
return "起始索引大于或等于结束索引,无法生成样例。"
available_examples = end_index - start_index
n = min(n, available_examples)
prompt = ""
# 随机选择n个不同的索引
random_indices = random.sample(range(start_index, end_index), n)
for i in random_indices:
text_prompt = text.iloc[i]
triple_prompt = label.iloc[i]
prompt += text_prompt + '\n' + str(triple_prompt) + '\n'
return prompt
def generate_prompt_with_best_matches(text_series, label_series, query_text, n=3, start_index=500, end_index=1000):
"""
基于相似度匹配生成渐进式提示
参数:
text: 文本数据Series
label: 三元组标签数据Series
query_text: 需要生成提示的查询文本
n: 最大样例数量
start_index: 候选样例起始索引
end_index: 候选样例结束索引
返回:
渐进式提示字符串 (1个样例、2个样例、3个样例...)
"""
query_text_path = './data/text_retrieval_results.json'
with open(query_text_path, 'r', encoding='utf-8') as f:
query_text_data = json.load(f)
# 根据输入的query_text,在query_text_data中找到对应的query_text索引的matched_texts,基于输入n选择提取matched_texts中context的个数
for item in query_text_data:
if item['query_text'] == query_text:
matched_texts = item['matched_texts']
break
matched_texts = matched_texts[:n]
# 将matched_texts中的context提取出来
context_list = [text['context'] for text in matched_texts]
# 基于context_list中文本匹配输入text和label作为提示
prompt = ""
# 遍历每个上下文
for context in context_list:
# 遍历整个文本序列
for idx, text_item in text_series.items():
if context in text_item:
# 获取对应的标签
position = text_series.index.get_loc(idx)
label = label_series.iloc[position]
# print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
# print(text_item)
# print(label)
prompt += f"{text_item}\n{label}\n\n"
break
return prompt.strip()
if __name__ == "__main__":
# 读取训练数据
train_data = pd.read_json('./data/train_triples.json')
text = train_data['text']
label = train_data['triple_list']
query_text_path = './data/text_retrieval_results.json'
with open(query_text_path, 'r', encoding='utf-8') as f:
query_text_data = json.load(f)
query_text = query_text_data[3]['query_text']
print(query_text)
print("--------------------------------")
prompt = generate_prompt_with_best_matches(text, label, query_text, n=3, start_index=500, end_index=1000)
print(prompt)