"""Constituency-parse review texts with benepar and compare parses via Jaccard similarity."""

import os
import pickle
from concurrent.futures import ThreadPoolExecutor

import benepar
import nltk
import torch
from nltk.tree import Tree
from tqdm import tqdm

nltk.data.path.append('data/nltk_data')


class Parser:
    """Wraps a benepar constituency parser and caches parsed reviews on disk as pickles."""

    def __init__(self):
        # Pin torch to CPU before constructing the benepar model.
        torch.set_default_device("cpu")
        self.parser = benepar.Parser("benepar_en3_large")
        self.parser.batch_size = 64
        # Pickle cache locations for the full eval set vs. the toy subset.
        self.parsed_eval_reviews_path = "data/parsed/parsed_reviews.pkl"
        self.parsed_toy_reviews_path = "data/parsed/parsed_toy_data_reviews.pkl"

    def subtree_set(self, tree: Tree) -> set:
        """
        Return a flat set of all subtrees as strings (hashable).

        Includes the tree itself; leaves (non-Tree children) are not added
        individually, only as part of their parents' string forms.
        """
        subs: set = set()

        def helper(t: Tree) -> None:
            # Convert each subtree to a string and add to the set
            subs.add(str(t))
            for child in t:
                if isinstance(child, Tree):
                    helper(child)

        helper(tree)
        return subs

    def parse_text(self, text: str):
        """Parse one review; return an nltk Tree, or None on any parser failure (best effort)."""
        try:
            return self.parser.parse(text[:10000])  # truncate long reviews
        except Exception as e:
            # Deliberate best-effort: a single bad review must not abort the batch.
            print(f"Parse error: {e}")
            return None

    def parse_reviews(self, reviews: list, toy_data: bool) -> list[set]:
        """
        Parse every review, convert each parse to a subtree-string set, pickle
        the list to the toy/eval cache path, and return it.

        Failed parses contribute an empty set, so output indices stay aligned
        with `reviews`.
        """
        # BUG FIX: os.cpu_count() may return None (TypeError on `- 1`), and on a
        # single-core machine cpu_count()-1 == 0, which ThreadPoolExecutor
        # rejects. Clamp to at least one worker.
        workers = max(1, (os.cpu_count() or 2) - 1)

        parsed_reviews: list[set] = []
        with ThreadPoolExecutor(max_workers=workers) as executor:
            # executor.map preserves input order, keeping results aligned.
            for tree in tqdm(executor.map(self.parse_text, reviews), total=len(reviews)):
                if isinstance(tree, Tree):
                    parsed_reviews.append(self.subtree_set(tree))
                else:
                    parsed_reviews.append(set())  # fallback for parse errors

        # Save parsed reviews; ensure the cache directory exists first
        # (fresh checkouts may not have data/parsed/).
        path = self.parsed_toy_reviews_path if toy_data else self.parsed_eval_reviews_path
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump(parsed_reviews, f)
        return parsed_reviews

    def compute_syntactic_similarity(self, query_tree_subs: set, review_tree_subs: set) -> float:
        """
        Jaccard similarity between two sets of subtrees (strings, hashable).

        Returns 0.0 when both sets are empty (avoids division by zero).
        """
        intersect = query_tree_subs.intersection(review_tree_subs)
        union = query_tree_subs.union(review_tree_subs)
        if not union:
            return 0.0
        return len(intersect) / len(union)

    def load_parsed_reviews(self, toy_data: bool) -> list[set]:
        """Load the pickled subtree-set list previously written by parse_reviews."""
        path = self.parsed_toy_reviews_path if toy_data else self.parsed_eval_reviews_path
        # NOTE(review): pickle.load is unsafe on untrusted files — this assumes
        # the cache was produced locally by parse_reviews.
        with open(path, "rb") as f:
            parsed_reviews = pickle.load(f)
        return parsed_reviews