|
|
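"""
Structural (metapath-guided) candidate retrieval for STaRK QA.

input: rg, a routing guide of the form {"Metapath": ..., "Restriction": {...}}

output (candidates for path-based reranking). The plain text-retrieval fallback
returns up to `topk` candidates (default 100); the structured, multi-hop branch
returns a variable number. On success the result is:

{
    "query": query,
    "pred_dict": {node_id: score},
    "ans_ids": ans_ids,
    "paths": {node_id: [node_ids_path]},
    "query_pattern": rg["Metapath"],
    "rg": rg,
}

When nothing is retrieved, "pred_dict" and "paths" are empty.
"""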
|
|
import sys |
|
|
import os |
|
|
sys.path.append(os.path.dirname(os.path.abspath(os.getcwd()))) |
|
|
|
|
|
from utils import combine_dicts, parse_metapath, get_scorer, get_text_retriever, fix_length |
|
|
from models.model import ModelForSTaRKQA |
|
|
import time |
|
|
|
|
|
|
|
|
|
|
|
class Stru4Path(ModelForSTaRKQA):
    """Metapath-guided candidate retrieval over a semi-structured knowledge base.

    A routing guide ``rg`` (``{"Metapath": ..., "Restriction": {...}}``) is turned
    into one or more type-level routes. Each valid route is walked from a few
    text-retrieved seed nodes through ``skb``. The resulting candidate pools are
    merged and scored by ``self.scorer``. They are returned together with one
    node path per candidate, for downstream path-based reranking.
    """

    def __init__(self, dataset_name, text_retriever_name, scorer_name, skb, topk=100):
        """
        Args:
            dataset_name: dataset identifier. For "prime", routes carry explicit
                relation names. For other datasets, routes list node types only,
                and relations are looked up in ``self.tp_dict``.
            text_retriever_name: forwarded to ``get_text_retriever``.
            scorer_name: forwarded to ``get_scorer``.
            skb: the knowledge base object.
            topk: number of candidates requested from the text retriever when
                the routing guide degenerates to a single node type.
        """
        super().__init__(skb)
        self.dataset_name = dataset_name
        self.text_retriever = get_text_retriever(dataset_name, text_retriever_name, skb)
        self.scorer = get_scorer(dataset_name, scorer_name=scorer_name, skb=skb)

        self.topk = topk
        self.node_type_list = skb.node_type_lst()
        self.edge_type_list = skb.rel_type_lst()
        if self.dataset_name == "prime":
            # prime routes look like [node, relation, node, ...]. A hop is valid
            # when its (head, relation, tail) triplet is among skb.get_tuples().
            self.tp_list = skb.get_tuples()
            self.target_type_list = skb.candidate_types
        else:
            # Routes for other datasets list node types only. Map each
            # (head_type, tail_type) pair to its relation; if several relations
            # share one pair, the last tuple from get_tuples() wins.
            self.tp_dict = {(tp[0], tp[-1]): tp[1] for tp in skb.get_tuples()}
            self.target_type_list = ['paper' if dataset_name == 'mag' else 'product']

        self.skb = skb
        # Number of seed nodes text-retrieved for the first node type of a route.
        self.ini_k = 5
        # Number of forward() calls answered through the multi-hop branch.
        self.stru_count = 0
        # Per-route {node_id: path} dicts gathered by get_candidates4route.
        # forward() resets this list and later replaces it with a single dict.
        self.paths = []

    def rg2routes(self, rg):
        """Extract type-level routes from a routing guide.

        input: rg: {"Metapath": "", "Restriction": {}}
        output: routes: [['paper', 'author', 'paper'], ['paper', 'paper']]

        A list-valued "Metapath" is returned unchanged, and a string is parsed
        with ``parse_metapath``. Any other value yields None.
        """
        metapath = rg["Metapath"]
        if isinstance(metapath, list):
            return metapath
        if isinstance(metapath, str):
            return parse_metapath(metapath)
        return None

    def check_valid(self, routes, rg):
        """Keep only the routes that can actually be walked in the knowledge base.

        Args:
            routes: list of routes, as produced by ``rg2routes``.
            rg: unused; kept for interface compatibility.

        Returns:
            - None if no route survives the checks.
            - The int 1 when the guide is one single-node route, which means
              plain text retrieval.
            - Otherwise, the valid routes in alternating
              [node, relation, node, ...] form.
        """
        if not routes:
            return None

        if len(routes) == 1 and len(routes[0]) == 1:
            return 1

        # 1) A route must end at a type that can be an answer.
        target_type_valid_routes = [
            route for route in routes if route[-1] in self.target_type_list
        ]
        if not target_type_valid_routes:
            return None

        # 2) Every element must be a known node type or relation type.
        type_valid_routes = [
            route
            for route in target_type_valid_routes
            if all(node in self.node_type_list or node in self.edge_type_list for node in route)
        ]
        if not type_valid_routes:
            return None

        # 3) Every hop must exist in the schema.
        relation_valid_routes = []
        for route in type_valid_routes:
            if self.dataset_name == "prime":
                triplets = [
                    (route[i], route[i + 1], route[i + 2])
                    for i in range(0, len(route) - 2, 2)
                ]
                if all(tp in self.tp_list for tp in triplets):
                    relation_valid_routes.append(route)
            elif len(route) == 1:
                # A lone target type has no hops to check. Keep it as-is, as the
                # prime branch does (its triplet list is empty). Building a
                # relation route for it used to raise IndexError on pairs[-1].
                relation_valid_routes.append(route)
            else:
                pairs = list(zip(route, route[1:]))
                if all(pair in self.tp_dict for pair in pairs):
                    # Interleave the looked-up relations between node types:
                    # [a, b, c] -> [a, rel(a, b), b, rel(b, c), c]
                    new_route = []
                    for head, tail in pairs:
                        new_route.append(head)
                        new_route.append(self.tp_dict[(head, tail)])
                    new_route.append(pairs[-1][-1])
                    relation_valid_routes.append(new_route)

        return relation_valid_routes or None

    @staticmethod
    def _restriction_text(restriction, node_type):
        """Return the restriction text for ``node_type`` as one string ("" if none).

        A restriction may be a single string or an iterable of strings. Items
        of an iterable are joined with spaces so separate phrases do not run
        together.
        """
        try:
            restr = restriction[node_type]
        except (KeyError, TypeError):
            return ""
        if isinstance(restr, str):
            return restr
        try:
            return " ".join(restr)
        except TypeError:
            # None, a non-iterable value, or non-string items: ignore the restriction.
            return ""

    def get_candidates4route(self, query, q_id, route, restriction):
        """Walk ``route`` from text-retrieved seed nodes and return the end nodes.

        The first node type of the route is seeded with the ``self.ini_k`` best
        text matches for ``query``, extended with that type's restriction text.
        Each following [relation, node_type] step expands the frontier through
        ``skb.get_neighbor_nodes``. One path per reached node is recorded, and
        the resulting {node_id: path} dict is appended to ``self.paths``.

        Returns:
            list of the node ids reached at the end of the route, in
            first-discovery order.
        """
        ini_node_type = route[0]
        extra_restr = self._restriction_text(restriction, ini_node_type)
        seed_scores = self.text_retriever.retrieve(
            query + " " + extra_restr, q_id=q_id, topk=self.ini_k, node_type=ini_node_type
        )

        frontier = list(seed_scores.keys())
        paths = {node_id: [node_id] for node_id in frontier}

        # The route alternates node types and relations: [n0, r0, n1, r1, n2, ...].
        # Each iteration follows relation route[hop + 1] out of the current frontier.
        for hop in range(0, len(route) - 2, 2):
            edge_type = route[hop + 1]
            # NOTE(review): neighbours are not filtered by route[hop + 2] (the
            # expected next node type). Confirm that get_neighbor_nodes returns
            # only nodes of that type for this edge_type.
            new_paths = {}
            for node_id in frontier:
                for neighbor_id in self.skb.get_neighbor_nodes(idx=node_id, edge_type=edge_type):
                    # If several frontier nodes reach the same neighbour, the
                    # last one visited provides its recorded path.
                    new_paths[neighbor_id] = paths[node_id] + [neighbor_id]
            paths = new_paths
            # Every neighbour becomes a key of new_paths, so its keys are the
            # de-duplicated next frontier, in first-discovery order.
            frontier = list(new_paths)

        self.paths.append(paths)
        return frontier

    def merge_candidate_pools(self, non_empty_candidates_lists):
        """Combine per-route candidate pools into a single candidate list.

        Returns:
            - The intersection of all pools when it is non-empty.
            - Otherwise, the union of all pools.
            - An empty list when no pools are given.
            The result is always a list.
        """
        if not non_empty_candidates_lists:
            return []
        if len(non_empty_candidates_lists) == 1:
            # De-duplicate while keeping the pool's order. This case used to
            # return a set, unlike the multi-pool case.
            return list(dict.fromkeys(non_empty_candidates_lists[0]))

        common = set(non_empty_candidates_lists[0]).intersection(*non_empty_candidates_lists[1:])
        if common:
            return list(common)
        # No candidate satisfies every route, so fall back to anything any route found.
        return list(set().union(*non_empty_candidates_lists))

    def get_mor_candidates(self, query, q_id, valid_routes, restriction):
        """Collect candidates from every route whose start type is restricted.

        Only routes whose first node type has a non-empty entry in
        ``restriction`` are walked. The non-empty pools they produce are merged
        with ``merge_candidate_pools``.

        Returns:
            {node_id: -1} placeholder scores for the merged candidates, or {}
            when nothing was found.
        """
        candidates_pool = [
            self.get_candidates4route(query, q_id, route, restriction)
            for route in valid_routes
            if route[0] in restriction and len(restriction[route[0]]) > 0
        ]

        non_empty_candidates_lists = [lst for lst in candidates_pool if lst]
        if not non_empty_candidates_lists:
            return {}

        # Merge only the non-empty pools. An empty pool would empty the
        # intersection and silently force the union fallback.
        candidates = self.merge_candidate_pools(non_empty_candidates_lists)
        if not candidates:
            return {}

        return dict.fromkeys(candidates, -1)

    @staticmethod
    def _empty_output(query, ans_ids, rg):
        """Result returned by forward() when no candidate could be produced."""
        return {
            "query": query,
            "pred_dict": {},
            "ans_ids": ans_ids,
            'paths': {},
            'query_pattern': rg['Metapath'],
            'rg': rg,
        }

    def forward(self, query, q_id, ans_ids, rg):
        """Retrieve and score candidates for ``query``, guided by routing guide ``rg``.

        Returns:
            dict with the keys:
            - "query"
            - "pred_dict": {node_id: score}
            - "ans_ids"
            - "paths": {node_id: [node ids along the path]}
            - "query_pattern": rg["Metapath"]
            - "rg"

        Raises:
            ValueError: if the scorer returns a different number of candidates
                than there are recorded paths.
        """
        self.paths = []

        routes = self.rg2routes(rg)
        valid_routes = self.check_valid(routes, rg)

        if valid_routes is None:
            return self._empty_output(query, ans_ids, rg)

        if valid_routes == 1:
            # A single one-node route: plain text retrieval over the first target type.
            pred_dict = self.text_retriever.retrieve(
                query, q_id=q_id, topk=self.topk, node_type=f'{self.target_type_list[0]}'
            )
        else:
            if self.dataset_name != "prime":
                # Keep only the final two hops ([n, r, n, r, n]) of each route.
                valid_routes = [route[-5:] for route in valid_routes]
            # A missing or null "Restriction" means no route gets walked.
            restriction = rg.get("Restriction") or {}
            pred_dict = self.get_mor_candidates(query, q_id, valid_routes, restriction)
            self.stru_count += 1

        if not pred_dict:
            self.paths = {}
            return self._empty_output(query, ans_ids, rg)

        if self.paths:
            self.paths = combine_dicts(self.paths, pred_dict=pred_dict)
        else:
            # No multi-hop walk happened: each candidate is its own one-node path.
            self.paths = {node_id: [node_id] for node_id in pred_dict}

        pred_dict = self.scorer.score(query, q_id, list(pred_dict.keys()))

        if len(self.paths) != len(pred_dict):
            print(f"paths: {self.paths}")
            print(f"pred_dict: {pred_dict}")
            raise ValueError(f"Length mismatch between paths and pred_dict: {len(self.paths)}, {len(pred_dict)}")

        return {
            "query": query,
            "pred_dict": pred_dict,
            "ans_ids": ans_ids,
            'paths': self.paths,
            'query_pattern': rg['Metapath'],
            'rg': rg,
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|