# team-149-project/utils/syntactic_similarity.py
# knguyen471's picture
# Upload 11 files
# 888aba6 verified
import torch
import pickle
import benepar
import nltk
from nltk.tree import Tree
import os
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
# Add the project-local NLTK data directory to NLTK's resource search path.
# The path is relative, so it depends on the process's working directory.
nltk.data.path.append('data/nltk_data')
class Parser:
    """Parse review texts with benepar and compare them by shared subtrees.

    Each parsed text is reduced to the set of string renderings of all of its
    subtrees; two texts are then compared with Jaccard similarity over those
    sets. Parsed review sets are cached on disk as pickle files.
    """

    def __init__(self):
        """Load the benepar model and record the cache file locations."""
        # torch.set_default_device changes the default device for the whole
        # process, not only for this parser.
        torch.set_default_device("cpu")
        self.parser = benepar.Parser("benepar_en3_large")
        self.parser.batch_size = 64
        # Pickle caches written by parse_reviews / read by load_parsed_reviews.
        self.parsed_eval_reviews_path = "data/parsed/parsed_reviews.pkl"
        self.parsed_toy_reviews_path = "data/parsed/parsed_toy_data_reviews.pkl"

    @staticmethod
    def _worker_count(cpu_count):
        """Return ``cpu_count - 1`` clamped to at least one worker.

        ``os.cpu_count()`` may return ``None`` and a single-core host would
        give zero workers, which ``ThreadPoolExecutor`` rejects.
        """
        return max(1, (cpu_count or 1) - 1)

    def subtree_set(self, tree: "Tree"):
        """Return a flat set with the string form of ``tree`` and every subtree.

        Strings are used because they are hashable and can be compared across
        independently parsed texts. Leaves (non-``Tree`` children) are not
        added on their own; they only appear inside their parents' strings.
        """
        subs = set()
        pending = [tree]
        # Explicit stack instead of a nested recursive helper; visiting order
        # does not matter because the result is a set.
        while pending:
            node = pending.pop()
            subs.add(str(node))
            pending.extend(child for child in node if isinstance(child, Tree))
        return subs

    def parse_text(self, text):
        """Parse one text, returning the parser's result or ``None`` on error.

        Only the first 10000 characters are parsed. Any exception (including a
        non-sliceable ``text``) is reported on stdout and turned into ``None``
        so that one bad review does not abort a whole batch.
        """
        try:
            return self.parser.parse(text[:10000])  # truncate long reviews
        except Exception as e:
            print(f"Parse error: {e}")
            return None

    def parse_reviews(self, reviews: list, toy_data: bool) -> list[set]:
        """Parse every review, cache the subtree sets on disk and return them.

        Args:
            reviews: Review texts; the result has one entry per review, in the
                same order.
            toy_data: Selects the toy-data cache file when true, otherwise the
                evaluation cache file.

        Returns:
            One set of subtree strings per review. Reviews that failed to
            parse (or did not yield a ``Tree``) are represented by an empty
            set.
        """
        if toy_data:
            output_path = self.parsed_toy_reviews_path
        else:
            output_path = self.parsed_eval_reviews_path

        # Create the cache directory before the (long) parsing step so a
        # missing directory cannot discard the finished work at save time.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        parsed_reviews = []
        # NOTE(review): all workers share self.parser; whether benepar is
        # thread-safe and whether threads give a real speed-up is not
        # established here -- confirm.
        workers = self._worker_count(os.cpu_count())
        with ThreadPoolExecutor(max_workers=workers) as executor:
            # executor.map yields results in input order.
            results = executor.map(self.parse_text, reviews)
            for tree in tqdm(results, total=len(reviews)):
                if isinstance(tree, Tree):
                    parsed_reviews.append(self.subtree_set(tree))
                else:
                    parsed_reviews.append(set())  # fallback for parse errors

        # Save parsed reviews
        with open(output_path, "wb") as f:
            pickle.dump(parsed_reviews, f)
        return parsed_reviews

    def compute_syntactic_similarity(self, query_tree_subs: set, review_tree_subs: set) -> float:
        """Return the Jaccard similarity of two sets of subtree strings.

        The result is ``|A & B| / |A | B|``; when both sets are empty the
        similarity is defined as ``0.0`` to avoid dividing by zero.
        """
        union = query_tree_subs.union(review_tree_subs)
        if not union:
            return 0.0
        intersect = query_tree_subs.intersection(review_tree_subs)
        return len(intersect) / len(union)

    def load_parsed_reviews(self, toy_data: bool) -> list[set]:
        """Load the cached subtree sets written by ``parse_reviews``.

        Args:
            toy_data: Selects the toy-data cache file when true, otherwise the
                evaluation cache file.

        Raises:
            OSError: If the cache file cannot be opened (e.g. it was never
                written).
        """
        path = self.parsed_toy_reviews_path if toy_data else self.parsed_eval_reviews_path
        # SECURITY: pickle.load can execute arbitrary code; only load cache
        # files that this project produced itself.
        with open(path, "rb") as f:
            parsed_reviews = pickle.load(f)
        return parsed_reviews