|
|
--- |
|
|
license: apache-2.0 |
|
|
--- |
|
|
|
|
|
```python |
|
|
|
|
|
from gen import get_answer,get_state |
|
|
import torch |
|
|
|
|
|
|
|
|
def load_state(train_state_path, layer=32, n_embd=2560): |
|
|
train_state = torch.load(pth_file_path, map_location=torch.device('cpu')) |
|
|
state = [None] * (layer * 3) |
|
|
for i in range(layer): |
|
|
state[i*3+0]=torch.zeros(n_embd,).to(dtype=torch.bfloat16,device='cuda') |
|
|
state[i*3+1]=train_state[f'blocks.{i}.att.time_state'].to(dtype=torch.float,device='cuda') |
|
|
state[i*3+2]=torch.zeros(n_embd,).to(dtype=torch.bfloat16,device='cuda') |
|
|
return state |
|
|
|
|
|
def get_instruction(): |
|
|
"""返回固定的指令内容""" |
|
|
return "根据input中的input和entity_types,帮助用户找到文本中每种entity_types的实体,标明实体类型并且简单描述。然后给找到实体之间的关系,并且描述这段关系以及对关系强度打分。 避免使用诸如\"其他\"或\"未知\"的通用实体类型。 非常重要的是:不要生成冗余或重叠的实体类型和关系。用JSON格式输出。" |
|
|
|
|
|
def get_content(input_text): |
|
|
"""输入内容文本,返回格式化的content部分""" |
|
|
return f"'{{'input': '{input_text}'}}" |
|
|
|
|
|
def get_entity_types(entity_list): |
|
|
""" |
|
|
输入实体类型列表,返回格式化的entity_types部分 |
|
|
|
|
|
Args: |
|
|
entity_list: 可以是字符串列表 ['领域', '专家', '任务'] |
|
|
或者是字符串 '领域, 专家, 任务' |
|
|
""" |
|
|
if isinstance(entity_list, str): |
|
|
# 如果是字符串,按逗号分割 |
|
|
entity_list = [item.strip() for item in entity_list.split(',')] |
|
|
|
|
|
# 不带引号的格式(和原数据一致) |
|
|
entity_str = ', '.join(entity_list) |
|
|
return f"{{'entity_types': [{entity_str}]}}" |
|
|
|
|
|
def generate_prompt(content, entity_types): |
|
|
""" |
|
|
生成完整的prompt |
|
|
|
|
|
Args: |
|
|
content: 输入的文本内容 |
|
|
entity_types: 实体类型列表或字符串 |
|
|
|
|
|
Returns: |
|
|
完整的prompt字符串 |
|
|
""" |
|
|
instruction = get_instruction() |
|
|
content_part = get_content(content) |
|
|
entity_types_part = get_entity_types(entity_types) |
|
|
input_list_str = f'["content": {content_part}, "entity_types": {entity_types_part}]' |
|
|
# 按照指定格式拼接 |
|
|
prompt = ( |
|
|
f"{input_list_str}\n\n" |
|
|
f"User: Act as a specialized AI for Knowledge Graph construction. Your task is to extract entities and their relationships from the provided input, based on the given entity_types provided in above content.\nStructure your output as a single, valid JSON object with two top-level keys: entities and relationships.\nentities: A list of objects. Each object must have:\nentity: The exact name of the entity.\ndescription: A brief, context-based summary of the entity.\nrelationships: A list of objects. Each object must have:\nsource: The name of the source entity.\ntarget: The name of the target entity.\nrelationship: A concise description of their connection.\nCritical Rules:\nStrict Typing: Use only the provided entity types. Do not invent types or use generics like \"Other\".\nNo Redundancy: Do not create duplicate or reciprocal relationships (e.g., if A acquired B exists, do not add B was acquired by A).\nYour response must be only the JSON object.\n\n" |
|
|
f"Assistant:" |
|
|
) |
|
|
|
|
|
return prompt |
|
|
|
|
|
content1 = "根据我国的监狱法令,为了协助监狱囚犯改过自新和重新融入社会,监禁期至少四个星期的囚犯可在服刑至少14天后转入居家宵禁计划,在家服满剩余的刑期" |
|
|
entity_types1 = ["法律法规", "人物类别", "时间条件", "政策措施"] |
|
|
|
|
|
ctx = generate_prompt(content1, entity_types1) |
|
|
|
|
|
pth_file_path = "/home/rwkv/models/triplets1/rwkv-0.pth" |
|
|
|
|
|
|
|
|
tt_state = load_state(pth_file_path) |
|
|
|
|
|
print(ctx) |
|
|
res1 = get_answer(ctx,state=tt_state) |
|
|
print('train_state :',res1) |
|
|
|