team-149-project / utils /syntactic_similarity.py
knguyen471's picture
Upload 11 files
888aba6 verified
raw
history blame
2.53 kB
import torch
import pickle
import benepar
import nltk
from nltk.tree import Tree
import os
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
# Let nltk also search the project's data/nltk_data directory for resources.
# The path is relative, so it is resolved against the current working directory.
nltk.data.path.append("data/nltk_data")
class Parser:
    """Constituency-parse reviews with benepar and compare them by shared subtrees.

    Each parsed review is represented as a set of subtree strings. Sets for a
    whole dataset are cached to a pickle file by ``parse_reviews`` and read
    back by ``load_parsed_reviews``.
    """

    # Reviews longer than this many characters are truncated before parsing.
    MAX_REVIEW_CHARS = 10000

    def __init__(self):
        # NOTE: this changes torch's process-wide default device, not just
        # the device used by this parser.
        torch.set_default_device("cpu")
        self.parser = benepar.Parser("benepar_en3_large")
        self.parser.batch_size = 64
        # Cache locations; relative paths, resolved against the working directory.
        self.parsed_eval_reviews_path = "data/parsed/parsed_reviews.pkl"
        self.parsed_toy_reviews_path = "data/parsed/parsed_toy_data_reviews.pkl"

    def _parsed_reviews_path(self, toy_data: bool) -> str:
        """Return the cache file for the toy dataset or the evaluation dataset."""
        return self.parsed_toy_reviews_path if toy_data else self.parsed_eval_reviews_path

    @staticmethod
    def _default_max_workers(cpu_count=None) -> int:
        """Return a thread count that leaves one CPU free but is always >= 1.

        Args:
            cpu_count: number of CPUs to plan for; ``None`` queries
                ``os.cpu_count()``.

        ``os.cpu_count()`` may return ``None``, and on a single-CPU machine
        ``cpu_count - 1`` would be 0, which ``ThreadPoolExecutor`` rejects.
        """
        if cpu_count is None:
            cpu_count = os.cpu_count()
        return max(1, (cpu_count or 1) - 1)

    def subtree_set(self, tree: "Tree") -> set:
        """
        Return a flat set of all subtrees as strings (hashable).

        Every ``Tree`` node reachable from ``tree`` (including ``tree`` itself)
        contributes ``str(node)``; leaf tokens that are not ``Tree`` instances
        are not added on their own. An explicit stack is used instead of
        recursion so very deep trees cannot hit Python's recursion limit.
        """
        subs = set()
        pending = [tree]
        while pending:
            node = pending.pop()
            subs.add(str(node))
            pending.extend(child for child in node if isinstance(child, Tree))
        return subs

    def parse_text(self, text):
        """Parse one review with benepar, or return ``None`` on failure.

        The text is truncated to ``MAX_REVIEW_CHARS`` characters first. The
        result is whatever ``benepar.Parser.parse`` returns (``parse_reviews``
        expects an nltk ``Tree``). Any exception, including one from slicing
        a non-string ``text``, is printed and turned into ``None`` so a single
        bad review does not abort a whole batch.
        """
        try:
            return self.parser.parse(text[:self.MAX_REVIEW_CHARS])
        except Exception as e:  # deliberate best-effort: keep the batch going
            print(f"Parse error: {e}")
            return None

    def parse_reviews(self, reviews: list, toy_data: bool) -> list[set]:
        """Parse every review concurrently and cache the resulting subtree sets.

        Args:
            reviews: review texts to parse.
            toy_data: if true, write the toy-data cache file; otherwise the
                evaluation cache file.

        Returns:
            One subtree set per review, in input order. Reviews that failed
            to parse get an empty set.
        """
        parsed_reviews = []
        with ThreadPoolExecutor(max_workers=self._default_max_workers()) as executor:
            # executor.map yields results in the same order as ``reviews``.
            for tree in tqdm(executor.map(self.parse_text, reviews), total=len(reviews)):
                if isinstance(tree, Tree):
                    parsed_reviews.append(self.subtree_set(tree))
                else:
                    parsed_reviews.append(set())  # fallback for parse errors

        path = self._parsed_reviews_path(toy_data)
        # Make sure the cache directory exists so the (expensive) parse is
        # not lost to a FileNotFoundError on a fresh checkout.
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump(parsed_reviews, f)
        return parsed_reviews

    def compute_syntactic_similarity(self, query_tree_subs: set, review_tree_subs: set) -> float:
        """
        Jaccard similarity |A & B| / |A | B| between two sets of subtree strings.

        Returns 0.0 when both sets are empty (the union is empty).
        """
        union = query_tree_subs.union(review_tree_subs)
        if not union:
            return 0.0
        intersect = query_tree_subs.intersection(review_tree_subs)
        return len(intersect) / len(union)

    def load_parsed_reviews(self, toy_data: bool) -> list[set]:
        """Load the subtree sets previously cached by ``parse_reviews``.

        Args:
            toy_data: if true, read the toy-data cache file; otherwise the
                evaluation cache file.

        SECURITY: ``pickle.load`` can execute arbitrary code. Only load cache
        files this project wrote itself, never files from untrusted sources.
        """
        with open(self._parsed_reviews_path(toy_data), "rb") as f:
            return pickle.load(f)