File size: 2,530 Bytes
888aba6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import pickle
import benepar
import nltk
from nltk.tree import Tree

import os
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor

# Search the project-local directory for NLTK resources (e.g. tokenizer
# models) before falling back to the default NLTK data locations.
nltk.data.path.append('data/nltk_data')

class Parser():
    """
    Wrapper around the benepar constituency parser.

    Parses review texts into sets of subtree strings, caches the parses
    as pickles under data/parsed/, and scores syntactic similarity
    between two parses via Jaccard similarity on their subtree sets.
    """

    def __init__(self):
        # Force CPU so the parser works on machines without a GPU.
        torch.set_default_device("cpu")
        self.parser = benepar.Parser("benepar_en3_large")
        self.parser.batch_size = 64
        # Pickle cache locations for the full eval set and the toy set.
        self.parsed_eval_reviews_path = "data/parsed/parsed_reviews.pkl"
        self.parsed_toy_reviews_path = "data/parsed/parsed_toy_data_reviews.pkl"

    def subtree_set(self, tree: Tree) -> set:
        """
        Return a flat set of all subtrees of `tree` as strings (hashable).

        The root itself is included. Uses nltk's built-in Tree.subtrees()
        instead of a hand-rolled recursive walk; both yield every Tree
        node (leaves are plain strings and are skipped either way).
        """
        return {str(sub) for sub in tree.subtrees()}

    def parse_text(self, text):
        """
        Parse a single text, returning an nltk Tree or None on failure.

        Input is truncated to 10,000 characters to keep very long reviews
        from stalling the parser. Errors are reported and swallowed so a
        single bad review does not abort a batch run.
        """
        try:
            return self.parser.parse(text[:10000])  # truncate long reviews
        except Exception as e:
            print(f"Parse error: {e}")
            return None

    def parse_reviews(self, reviews: list, toy_data: bool) -> list[set]:
        """
        Parse every review into its set of subtree strings, pickle the
        result, and return it.

        Reviews that fail to parse map to an empty set, so the returned
        list is index-aligned with `reviews`.
        """
        # os.cpu_count() may return None, and on a single-core machine
        # cpu_count() - 1 would be 0, which ThreadPoolExecutor rejects —
        # clamp to at least one worker.
        workers = max(1, (os.cpu_count() or 1) - 1)

        parsed_reviews = []
        with ThreadPoolExecutor(max_workers=workers) as executor:
            for tree in tqdm(executor.map(self.parse_text, reviews), total=len(reviews)):
                if isinstance(tree, Tree):
                    parsed_reviews.append(self.subtree_set(tree))
                else:
                    parsed_reviews.append(set())  # fallback for parse errors

        # Save parsed reviews; create the cache directory if it is missing
        # so the dump does not fail with FileNotFoundError on a fresh checkout.
        path = self.parsed_toy_reviews_path if toy_data else self.parsed_eval_reviews_path
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump(parsed_reviews, f)

        return parsed_reviews

    def compute_syntactic_similarity(self, query_tree_subs: set, review_tree_subs: set) -> float:
        """
        Jaccard similarity between two sets of subtrees (strings, hashable).

        Returns |A ∩ B| / |A ∪ B|, or 0.0 when both sets are empty.
        """
        intersect = query_tree_subs.intersection(review_tree_subs)
        union = query_tree_subs.union(review_tree_subs)
        if not union:
            return 0.0
        return len(intersect) / len(union)

    def load_parsed_reviews(self, toy_data: bool) -> list[set]:
        """
        Load a previously pickled list of subtree sets from the cache.

        NOTE(review): pickle.load executes arbitrary code from the file —
        only load caches this project produced itself, never untrusted data.
        """
        path = self.parsed_toy_reviews_path if toy_data else self.parsed_eval_reviews_path
        with open(path, "rb") as f:
            parsed_reviews = pickle.load(f)
        return parsed_reviews