Spaces:
Sleeping
Sleeping
File size: 2,530 Bytes
888aba6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import torch
import pickle
import benepar
import nltk
from nltk.tree import Tree
import os
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
# Runs at import time: add the project-local 'data/nltk_data' directory to the
# list of locations NLTK searches when loading its data resources.
nltk.data.path.append('data/nltk_data')
class Parser():
    """Constituency-parse reviews with benepar and compare them by shared subtrees.

    Each review is reduced to the set of string renderings of every subtree of
    its parse tree. Two such sets are compared with Jaccard similarity.
    Parsed review sets are cached on disk as pickle files.
    """

    # Reviews longer than this many characters are truncated before parsing.
    MAX_TEXT_CHARS = 10000

    def __init__(self):
        """Load the benepar model on CPU and set the cache file locations."""
        # NOTE: this changes torch's default device for the whole process,
        # not only for this parser.
        torch.set_default_device("cpu")
        self.parser = benepar.Parser("benepar_en3_large")
        self.parser.batch_size = 64
        self.parsed_eval_reviews_path = "data/parsed/parsed_reviews.pkl"
        self.parsed_toy_reviews_path = "data/parsed/parsed_toy_data_reviews.pkl"

    @staticmethod
    def _max_workers() -> int:
        """Return the worker-thread count: all cores but one, never below 1.

        ``os.cpu_count()`` may return ``None`` and may return 1; in both cases
        a plain ``os.cpu_count() - 1`` would make ``ThreadPoolExecutor`` fail.
        """
        return max(1, (os.cpu_count() or 1) - 1)

    def _parsed_path(self, toy_data: bool) -> str:
        """Return the cache file path for the toy or the evaluation data set."""
        if toy_data:
            return self.parsed_toy_reviews_path
        return self.parsed_eval_reviews_path

    def subtree_set(self, tree: "Tree"):
        """
        Return a flat set of all subtrees as strings (hashable).

        The root itself is included. Leaves (children that are not ``Tree``
        instances) are not added as separate entries.
        """
        subs = set()
        # Explicit stack instead of recursion so that very deep trees cannot
        # hit the interpreter's recursion limit.
        pending = [tree]
        while pending:
            node = pending.pop()
            subs.add(str(node))
            pending.extend(child for child in node if isinstance(child, Tree))
        return subs

    def parse_text(self, text):
        """Parse one review and return the parser's result, or ``None`` on failure.

        The text is truncated to ``MAX_TEXT_CHARS`` characters first. Any
        exception raised while parsing is reported on stdout and swallowed so
        that one bad review does not abort a whole batch.
        """
        try:
            return self.parser.parse(text[:self.MAX_TEXT_CHARS])  # truncate long reviews
        except Exception as e:
            print(f"Parse error: {e}")
            return None

    def parse_reviews(self, reviews: list, toy_data: bool) -> list[set]:
        """Parse all reviews, cache the resulting subtree sets and return them.

        Args:
            reviews: Review texts, one string per review.
            toy_data: If true, write to the toy-data cache file; otherwise to
                the evaluation cache file.

        Returns:
            One set of subtree strings per review, in input order. A review
            that could not be parsed contributes an empty set.
        """
        with ThreadPoolExecutor(max_workers=self._max_workers()) as executor:
            # executor.map yields results in input order, so positions in the
            # output line up with positions in ``reviews``.
            trees = tqdm(executor.map(self.parse_text, reviews), total=len(reviews))
            parsed_reviews = [
                self.subtree_set(tree) if isinstance(tree, Tree) else set()  # fallback for parse errors
                for tree in trees
            ]

        # Save parsed reviews. Create the target directory first so that a
        # missing folder does not throw away the parsing work done above.
        path = self._parsed_path(toy_data)
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump(parsed_reviews, f)
        return parsed_reviews

    def compute_syntactic_similarity(self, query_tree_subs: set, review_tree_subs: set) -> float:
        """
        Jaccard similarity between two sets of subtrees (strings, hashable).

        Returns ``|A & B| / |A | B|``, or 0.0 when both sets are empty.
        """
        intersect = query_tree_subs.intersection(review_tree_subs)
        union = query_tree_subs.union(review_tree_subs)
        if not union:
            return 0.0
        return len(intersect) / len(union)

    def load_parsed_reviews(self, toy_data: bool) -> list[set]:
        """Load the cached subtree sets written by ``parse_reviews``.

        Args:
            toy_data: If true, read the toy-data cache file; otherwise the
                evaluation cache file.

        Raises:
            FileNotFoundError: If the cache file does not exist.
        """
        path = self._parsed_path(toy_data)
        # SECURITY: pickle.load can execute arbitrary code. Only load cache
        # files that this project produced itself.
        with open(path, "rb") as f:
            parsed_reviews = pickle.load(f)
        return parsed_reviews