# Spaces: Running
# Standard library
import os
import pickle
from concurrent.futures import ThreadPoolExecutor

# Third-party
import benepar
import nltk
import torch
from nltk.tree import Tree
from tqdm import tqdm

# Also search the relative directory data/nltk_data for NLTK resources.
nltk.data.path.append('data/nltk_data')
class Parser:
    """Parse reviews with benepar and compare them by their shared subtrees.

    Each review is reduced to the set of string forms of every subtree in
    its parse tree; two such sets are compared with Jaccard similarity.
    Parsed review sets are cached on disk as pickle files.
    """

    # Reviews are cut to this many characters before being parsed.
    MAX_REVIEW_CHARS = 10000

    def __init__(self):
        """Load the benepar model and set the pickle cache locations."""
        # Make CPU the default torch device before the model is loaded.
        torch.set_default_device("cpu")
        self.parser = benepar.Parser("benepar_en3_large")
        self.parser.batch_size = 64
        # Files written by parse_reviews and read by load_parsed_reviews.
        self.parsed_eval_reviews_path = "data/parsed/parsed_reviews.pkl"
        self.parsed_toy_reviews_path = "data/parsed/parsed_toy_data_reviews.pkl"

    @staticmethod
    def _max_workers(cpu_count: "int | None") -> int:
        """Return a valid thread count: one less than the CPUs, at least 1.

        ``os.cpu_count()`` may return None, and on a single-CPU host
        ``cpu_count - 1`` is 0, which ThreadPoolExecutor rejects.
        """
        return max(1, (cpu_count or 1) - 1)

    def _parsed_path(self, toy_data: bool) -> str:
        """Return the pickle path for the toy or the evaluation data set."""
        if toy_data:
            return self.parsed_toy_reviews_path
        return self.parsed_eval_reviews_path

    def _save_parsed_reviews(self, parsed_reviews: list, toy_data: bool) -> None:
        """Pickle ``parsed_reviews``, creating the target directory if needed."""
        path = self._parsed_path(toy_data)
        directory = os.path.dirname(path)
        # Without this, a missing directory would discard a finished parse run.
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump(parsed_reviews, f)

    def subtree_set(self, tree: "Tree"):
        """Return a flat set of all subtrees of ``tree`` as strings (hashable).

        The set contains the string form of ``tree`` itself and of every
        nested ``Tree`` node; leaves that are not ``Tree`` instances are
        skipped.
        """
        subs = set()
        # Explicit stack instead of recursion, so tree depth is not limited
        # by the interpreter's recursion limit.
        pending = [tree]
        while pending:
            node = pending.pop()
            subs.add(str(node))
            pending.extend(child for child in node if isinstance(child, Tree))
        return subs

    def parse_text(self, text):
        """Parse one review, returning None if parsing fails.

        Args:
            text: The review; only the first ``MAX_REVIEW_CHARS`` characters
                are passed to the parser.

        Returns:
            Whatever ``self.parser.parse`` returns (``parse_reviews`` expects
            an nltk ``Tree``), or None when any exception is raised.
        """
        try:
            return self.parser.parse(text[:self.MAX_REVIEW_CHARS])  # truncate long reviews
        except Exception as e:
            # Deliberately best-effort: one bad review must not abort a batch.
            print(f"Parse error: {e}")
            return None

    def parse_reviews(self, reviews: list, toy_data: bool) -> list[set]:
        """Parse all reviews, cache the result on disk and return it.

        Args:
            reviews: The review texts.
            toy_data: Selects the toy cache file instead of the evaluation one.

        Returns:
            One subtree-string set per review, in input order. A review whose
            parse did not produce a ``Tree`` contributes an empty set.
        """
        workers = self._max_workers(os.cpu_count())
        # NOTE(review): all threads share the one benepar parser; whether
        # that is thread-safe or gives real parallelism is not visible
        # here - confirm.
        with ThreadPoolExecutor(max_workers=workers) as executor:
            trees = tqdm(executor.map(self.parse_text, reviews), total=len(reviews))
            parsed_reviews = [
                # Empty set is the fallback for parse errors.
                self.subtree_set(tree) if isinstance(tree, Tree) else set()
                for tree in trees
            ]
        # Save parsed reviews
        self._save_parsed_reviews(parsed_reviews, toy_data)
        return parsed_reviews

    def compute_syntactic_similarity(self, query_tree_subs: set, review_tree_subs: set) -> float:
        """Return the Jaccard similarity of two sets of subtree strings.

        The result is ``|A & B| / |A | B|``, or 0.0 when both sets are empty.
        """
        union = query_tree_subs.union(review_tree_subs)
        if not union:
            return 0.0
        intersect = query_tree_subs.intersection(review_tree_subs)
        return len(intersect) / len(union)

    def load_parsed_reviews(self, toy_data: bool) -> list[set]:
        """Load the parsed reviews previously saved for the chosen data set.

        Raises whatever ``open`` or ``pickle.load`` raises, e.g.
        FileNotFoundError when nothing has been saved yet.
        """
        path = self._parsed_path(toy_data)
        with open(path, "rb") as f:
            # SECURITY: pickle.load can execute arbitrary code; only load
            # cache files that this class wrote itself.
            return pickle.load(f)