File size: 7,473 Bytes
c7a6fe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import json
import re
import concurrent.futures
from openai import OpenAI

class MedicalClaimVerifier:
    """Reward scorer for RL training of literacy-tiered medical summarizers.

    Given a model completion containing three rewrites of a medical text
    (low / intermediate / proficient health literacy), it checks each rewrite
    against two subclaim sets via an LLM judge:

    - completeness ("comp"): fraction of gold-summary subclaims supported,
    - coverage ("cov"): fraction of full-text subclaims supported,

    then converts pass/fail against per-level thresholds into a scalar reward.
    """

    def __init__(self):
        """Load the OpenAI API key from disk and set per-level thresholds.

        Raises:
            FileNotFoundError / KeyError: if the key file or "openai" entry
                is missing (path is machine-specific — TODO: make configurable).
        """
        # OpenAI API configuration.
        api_file = "/home/mshahidul/api_new.json"
        with open(api_file, "r") as f:
            api_keys = json.load(f)
        self.api_key = api_keys["openai"]
        self.model_name = "gpt-5-mini"
        self.client = OpenAI(api_key=self.api_key)

        # Literacy ranges (IQR after outlier removal) from paper summary.
        # comp = completeness vs gold summary; cov = source_coverage vs full text.
        self.threshold_ranges = {
            "low": {"comp": (0.9600, 1.0000), "cov": (0.1765, 0.3226)},
            "intermediate": {"comp": (0.9393, 1.0000), "cov": (0.1818, 0.4091)},
            "proficient": {"comp": (0.9231, 1.0000), "cov": (0.7725, 0.9347)},
        }

        # Minimum required information (upper bound of IQR).
        self.thresholds = {
            "low": {"comp": 1.0, "cov": 0.3226},
            "intermediate": {"comp": 1.0, "cov": 0.4091},
            "proficient": {"comp": 1.0, "cov": 0.9347},
        }

    def get_prompt(self, context, claim):
        """Build the judge prompt asking whether CONTEXT supports CLAIM."""
        prompt = f"""
        CONTEXT:
        {context}

        CLAIM TO VERIFY:
        {claim}

        INSTRUCTION:
        Does the CONTEXT above provide enough evidence to support the CLAIM? 
        - Answer 'supported' if the claim is explicitly stated or logically followable.
        - Answer 'not_supported' if the claim is missing, contradicts the text, or requires outside info.

        Output only one word: 'supported' or 'not_supported'.
        """
        return prompt

    def check_support_api(self, prompt):
        """Query the judge model; return 1.0 if the claim is supported, else 0.0.

        API failures are treated as "not supported" (0.0) rather than raised,
        so a transient outage cannot crash an RL training step.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
            )
            res = response.choices[0].message.content.strip().lower()
            # Check the negation first: "supported" is a substring of both
            # "not_supported" and "not supported", so a plain substring test
            # would misclassify the space-separated variant as supported.
            if re.search(r"\bnot[\s_]?supported\b", res):
                return 0.0
            return 1.0 if "supported" in res else 0.0
        except Exception as e:
            print(f"API call error: {e}")
            return 0.0

    def evaluate_level(self, gen_text, gold_subs, full_subs, level_key):
        """Calculate (completeness, coverage) scores for one literacy level.

        Args:
            gen_text: generated summary text for this level.
            gold_subs: subclaims from the gold summary (completeness check).
            full_subs: subclaims from the full source text (coverage check).
            level_key: literacy level name (kept for logging/debugging only).

        Returns:
            (comp_score, cov_score) in [0, 1]; (0.0, 0.0) on empty input
            or judge failure.
        """
        if not gen_text:
            return 0.0, 0.0

        # Run API calls in parallel to save time during RL.
        try:
            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                # Completeness check (vs gold-summary subclaims).
                comp_prompts = [self.get_prompt(gen_text, s) for s in gold_subs]
                comp_results = list(executor.map(self.check_support_api, comp_prompts))
                comp_score = sum(comp_results) / len(comp_results) if comp_results else 0.0

                # Coverage check (vs full-text subclaims).
                cov_prompts = [self.get_prompt(gen_text, s) for s in full_subs]
                cov_results = list(executor.map(self.check_support_api, cov_prompts))
                cov_score = sum(cov_results) / len(cov_results) if cov_results else 0.0
        except Exception as e:
            print(f"Parallel API call error: {e}")
            return 0.0, 0.0

        return comp_score, cov_score

    def get_reward_score(self, completion, gold_subs, full_subs):
        """Convert a model completion into a scalar RL reward.

        Penalties: -5.0 for unparseable JSON / malformed completion structure,
        -2.0 for valid JSON missing a required literacy key. Otherwise each
        level contributes +1.0 per threshold passed (comp and cov) and -1.0
        per threshold failed, for a total in [-6.0, +6.0].
        """
        data = None

        # 1. Robust JSON extraction.
        try:
            # Clean potential markdown fences or whitespace.
            text = completion[0]['content'].strip().replace("```json", "").replace("```", "").strip()
            data = json.loads(text)
        # KeyError/TypeError/AttributeError cover malformed completion
        # structures (missing 'content', non-list input, non-string content)
        # so they earn the format penalty instead of crashing training.
        except (json.JSONDecodeError, IndexError, ValueError, KeyError,
                TypeError, AttributeError):
            print("JSON Parsing Error in Reward Calculation")
            # If all extraction attempts fail.
            return -5.0

        # 2. Schema validation.
        levels = ["low", "intermediate", "proficient"]
        # Check if any required keys are missing.
        if not all(f"{lvl}_health_literacy" in data for lvl in levels):
            return -2.0  # Slightly smaller penalty for partial formatting success

        # 3. Scoring logic.
        try:
            total_reward = 0.0
            pass_reward = 1.0
            fail_penalty = -1.0
            for lvl in levels:
                gen_text = data.get(f"{lvl}_health_literacy", "")

                # Skip scoring (single penalty) if text is empty.
                if not gen_text:
                    total_reward += fail_penalty
                    continue

                comp_score, cov_score = self.evaluate_level(gen_text, gold_subs, full_subs, lvl)

                # Apply thresholds: one reward/penalty each for comp and cov.
                total_reward += pass_reward if comp_score >= self.thresholds[lvl]["comp"] else fail_penalty
                total_reward += pass_reward if cov_score >= self.thresholds[lvl]["cov"] else fail_penalty

            return total_reward
        except Exception:
            return -5.0


# --- Demo / smoke test of the reward function ---------------------------------

# 1. Ground Truth Subclaims (extracted from a medical paper on Hypertension).
# Gold-summary subclaims drive the completeness ("comp") check.
gold_summary_subclaims = [
    "Hypertension is defined as blood pressure above 140/90 mmHg.",
    "Lifestyle changes like low salt intake can reduce blood pressure.",
    "Diuretics are often the first line of pharmacological treatment."
]

# Full-text subclaims drive the source-coverage ("cov") check; they are a
# superset of the gold-summary subclaims.
full_text_subclaims = [
    "Hypertension is defined as blood pressure above 140/90 mmHg.",
    "Lifestyle changes like low salt intake can reduce blood pressure.",
    "Diuretics are often the first line of pharmacological treatment.",
    "The DASH diet emphasizes fruits, vegetables, and low-fat dairy.",
    "Chronic hypertension increases the risk of stroke and myocardial infarction.",
    "ACE inhibitors are contraindicated during pregnancy.",
    "Secondary hypertension can be caused by renal artery stenosis."
]

# 2. Mock model completion (the output being evaluated).
# This mimics the format the RL environment would pass to the reward function:
# a list with one message dict whose 'content' is a JSON string.
mock_completion = [{
    'content': """
    {
        "low_health_literacy": "High blood pressure is when your blood is too strong for your veins. You should eat less salt to help stay healthy.",
        "intermediate_health_literacy": "Hypertension is blood pressure over 140/90. You can lower it by eating less salt and taking water pills (diuretics) if your doctor says so.",
        "proficient_health_literacy": "Hypertension (BP > 140/90 mmHg) is managed via lifestyle modifications like the DASH diet and salt restriction. Pharmacological interventions include diuretics as first-line therapy, though risks like stroke or heart attack persist if untreated. Secondary causes like renal artery stenosis should be screened, and ACE inhibitors must be avoided in pregnancy."
    }
    """
}]

# Initialize the verifier (reads the API key file; requires network access).
verifier = MedicalClaimVerifier()

# Test the reward calculation end-to-end against the mock completion.
reward = verifier.get_reward_score(
    completion=mock_completion,
    gold_subs=gold_summary_subclaims,
    full_subs=full_text_subclaims
)

# Plain strings: the f-prefix was unnecessary (no placeholders).
print("--- Evaluation Result ---")
print(f"Total Reward Score: {reward}")

# Logic Explanation:
# - Low: Likely fails 'comp' (missing 140/90 info), but might pass 'cov' (low threshold).
# - Intermediate: Likely passes 'comp' and 'cov'.
# - Proficient: Needs to cover almost all 7 subclaims to pass the 0.77 coverage threshold.