File size: 7,002 Bytes
7a8b33f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import json
import argparse
import multiprocessing as mp
from zsvision.zs_multiproc import starmap_with_kwargs
from typing import List, Dict
import numpy as np
from zsvision.zs_utils import BlockTimer
from llm_api_utils import (
    call_openai_with_exponetial_backoff,
    estimate_cost_of_text_generation_api_call,
    init_openai_with_api_key,
)
import random


class ClassifyClaims:
    """Classify claims as objective or subjective with an OpenAI chat model.

    Few-shot examples are loaded from two local text files (one claim per
    line).  Claims are classified in batches — optionally in parallel via a
    process pool — to keep each prompt a manageable size.
    """

    def __init__(
        self,
        temperature=0,
        model="gpt-3.5-turbo",
        max_claims_per_api_call=10,
        processes=8,
        filter_str="",
        refresh=False,
    ):
        self.temperature = temperature
        self.model = model
        # cap claims per API call: overly long prompts degrade model output
        self.max_claims_per_api_call = max_claims_per_api_call
        self.processes = processes
        self.filter_str = filter_str
        self.refresh = refresh
        # few-shot example files, one claim per line
        self.objective_claims_file = "objective_claims.txt"
        self.subjective_claims_file = "subjective_claims.txt"

    def parse_classification_label(self, text: str) -> str:
        """Return "objective" or "subjective" parsed from a labelled claim.

        Args:
            text: a claim suffixed with "[objective]" or "[subjective]".

        Raises:
            ValueError: if neither label suffix is present.
        """
        raw = text.strip()
        if raw.endswith("[objective]"):
            return "objective"
        if raw.endswith("[subjective]"):
            return "subjective"
        raise ValueError(f"Invalid label: {raw}")

    def read_file(self, file_name):
        """Return the lines of `file_name`, stripped of surrounding whitespace."""
        with open(file_name, "r") as f:
            return [line.strip() for line in f]

    def create_few_shot_learning_prompt(self) -> str:
        """Build the few-shot example section of the classification prompt.

        The labelled examples are shuffled with a fixed seed so the prompt
        is deterministic across runs.
        """
        objective_list = self.read_file(self.objective_claims_file)
        subjective_list = self.read_file(self.subjective_claims_file)
        merged_list = [(claim, "[objective]") for claim in objective_list]
        merged_list += [(claim, "[subjective]") for claim in subjective_list]

        # fixed seed keeps the example ordering reproducible
        random.seed(1234)
        random.shuffle(merged_list)
        claims_section = "".join(f"{claim}\n" for claim, _ in merged_list)
        labelled_section = "".join(
            f"{claim} {classif}\n" for claim, classif in merged_list
        )
        return f"Claims:\n{claims_section}\nClassifications:\n{labelled_section}"

    def classify_claim_batch(
        self,
        idx: int,
        total: int,
        claims_and_sources_batch: List[Dict[str, str]],
    ):
        """Classify one batch of claims via a single chat-completion call.

        Args:
            idx: zero-based batch index (used only for progress logging).
            total: total number of batches (used only for progress logging).
            claims_and_sources_batch: dicts with (at least) a "claim" key;
                each dict gains a "label" key ("objective"/"subjective").

        Returns:
            dict with "claims_with_labels" (the annotated input dicts) and
            "cost" (estimated API cost in USD).
        """
        print(
            f"Processing batch {idx+1} of {total} (containing {len(claims_and_sources_batch)} claims)"
        )

        claim_str = "\n".join([claim["claim"] for claim in claims_and_sources_batch])
        num_batch_claims = len(claims_and_sources_batch)
        few_shot = self.create_few_shot_learning_prompt()
        prompt = f"""\
Objective claims can be verified based on factual data (such as those that could be verified by \
referencing an encyclopedia), whereas subjective claims involve a personal interpretation of \
the data and are more open to debate. \
For each of the following claims given below the dashed horizontal line, classify them as \
[subjective] or [objective] by suffixing the claim with the appropriate label. OUTPUT ONLY the class, either subjective or objective for each claim!

Here are some examples:

{few_shot}
----------
Claims:
{claim_str}

Classifications:\
"""
        persona = "You are a careful research assistant who helps with fact-checking and editing informative articles."
        system_message = {"role": "system", "content": persona}
        user_message = {"role": "user", "content": prompt}
        messages = [system_message, user_message]

        with BlockTimer(f"Using OpenAI API to extract claims with {self.model}"):
            response = call_openai_with_exponetial_backoff(
                model=self.model,
                temperature=self.temperature,
                messages=messages,
            )

        cost = estimate_cost_of_text_generation_api_call(
            model=self.model, response=response, verbose=True
        )

        # the model is expected to return exactly one labelled claim per line
        content = response.choices[0].message.content
        batch_classified_claims = content.split("\n")
        assert (
            len(batch_classified_claims) == num_batch_claims
        ), f"Expected {num_batch_claims} claims, but got {len(batch_classified_claims)}"
        print(f"Generated {len(batch_classified_claims)} claims (cost: {cost:.4f} USD)")

        claims_with_labels = []
        for claim_and_source, classified_claim in zip(
            claims_and_sources_batch, batch_classified_claims
        ):
            claim_and_source["label"] = self.parse_classification_label(
                classified_claim
            )
            claims_with_labels.append(claim_and_source)
        return {"claims_with_labels": claims_with_labels, "cost": cost}

    def classify_claims(self, claims_and_sources):
        """Classify claims as being either subjective or objective.

        Args:
            claims_and_sources: dicts with (at least) a "claim" key.

        Returns:
            The input dicts, each annotated with a "label" key.
        """
        init_openai_with_api_key()
        num_claims = len(claims_and_sources)

        # we limit the number of claims per api call (otherwise GPT-4 can choke)
        num_batches = int(np.ceil(num_claims / self.max_claims_per_api_call))
        claims_and_sources_batches = [
            batch.tolist() for batch in np.array_split(claims_and_sources, num_batches)
        ]

        kwarg_list = [
            {
                "idx": idx,
                "total": len(claims_and_sources_batches),
                "claims_and_sources_batch": claims_and_sources_batch,
            }
            for idx, claims_and_sources_batch in enumerate(claims_and_sources_batches)
        ]

        if self.processes == 1:
            batch_results = [
                self.classify_claim_batch(**kwargs) for kwargs in kwarg_list
            ]
        else:  # multiprocess
            func = self.classify_claim_batch
            with mp.Pool(processes=self.processes) as pool:
                batch_results = starmap_with_kwargs(
                    pool=pool, func=func, kwargs_iter=kwarg_list
                )

        cost = sum(result["cost"] for result in batch_results)
        labelled_claims = []
        for batch in batch_results:
            labelled_claims.extend(batch["claims_with_labels"])

        print(f"Returning {len(labelled_claims)} claims (cost: {cost} USD)")
        return labelled_claims

    def filter_to_objective_claims(self, claims):
        """Filter claims to only those that are objective."""

        objective_claims = [claim for claim in claims if claim["label"] == "objective"]

        print(f"Returning {len(objective_claims)} objective claims")
        return objective_claims