File size: 4,474 Bytes
030876e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import json
import re
import concurrent.futures
from openai import OpenAI

class MedicalClaimVerifier:
    """Verify whether generated medical text supports individual sub-claims.

    Each claim is checked with one yes/no prompt against an OpenAI chat
    model; per-level scores are aggregated in ``evaluate_level``.
    """

    def __init__(self):
        # Load the API key from a local credentials file.
        # NOTE(review): hard-coded absolute path — importing this module
        # fails if the file is absent; consider an env-var fallback.
        api_file = "/home/mshahidul/api_new.json"
        with open(api_file, "r") as f:
            api_keys = json.load(f)
        self.api_key = api_keys["openai"]
        # Note: Ensure gpt-5-nano is actually available in your tier
        self.model_name = "gpt-5-nano" 
        self.client = OpenAI(api_key=self.api_key)

        # Per-literacy-level minimum scores consumed by compute_score's
        # reward shaping: "comp" = comprehensiveness, "cov" = coverage.
        self.thresholds = {
            "low": {"comp": 1.0, "cov": 0.3226},
            "intermediate": {"comp": 1.0, "cov": 0.4091},
            "proficient": {"comp": 1.0, "cov": 0.9347},
        }

    def get_prompt(self, context, claim):
        """Build the single-claim verification prompt sent to the model."""
        prompt = f"""
        CONTEXT:
        {context}

        CLAIM TO VERIFY:
        {claim}

        INSTRUCTION:
        Does the CONTEXT above provide enough evidence to support the CLAIM? 
        - Answer 'supported' if the claim is explicitly stated or logically followable.
        - Answer 'not_supported' if the claim is missing, contradicts the text, or requires outside info.

        Output only one word: 'supported' or 'not_supported'.
        """
        return prompt

    def _classify(self, res):
        """Map a raw model reply to 1.0 (supported) or 0.0 (not supported).

        Whitespace/hyphen runs are normalized to underscores first so replies
        such as "not supported" or "Not-Supported" are recognized as negative.
        (A plain substring test misreads "not supported" as supported, because
        it contains "supported" but not the literal "not_supported".)
        """
        normalized = re.sub(r"[\s_-]+", "_", res.strip().lower())
        return 1.0 if "supported" in normalized and "not_supported" not in normalized else 0.0

    def check_support_api(self, prompt):
        """Ask the model one verification question; return 1.0 or 0.0.

        Any API or parsing failure is scored 0.0 (best-effort) so a single
        flaky call cannot crash a whole evaluation batch.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
            )
            res = response.choices[0].message.content.strip().lower()
            return self._classify(res)
        except Exception:
            # Deliberate swallow: treat errors as "unsupported" rather than raise.
            return 0.0

    def evaluate_level(self, gen_text, gold_subs, full_subs):
        """Score one literacy level's generated text.

        Returns (comp_score, cov_score):
          comp_score — fraction of gold (summary) sub-claims supported
          cov_score  — fraction of full-text sub-claims supported
        Returns (0.0, 0.0) when any input is empty.
        """
        if not gen_text or not gold_subs or not full_subs:
            return 0.0, 0.0

        # Single batched pool over both claim lists to reduce overhead.
        all_claims = gold_subs + full_subs
        prompts = [self.get_prompt(gen_text, claim) for claim in all_claims]
        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            results = list(executor.map(self.check_support_api, prompts))

        # First len(gold_subs) results are comprehensiveness, rest coverage.
        comp_results = results[:len(gold_subs)]
        cov_results = results[len(gold_subs):]

        comp_score = sum(comp_results) / len(gold_subs)
        cov_score = sum(cov_results) / len(full_subs)
        return comp_score, cov_score

# Module-level singleton used by compute_score below. NOTE(review): this runs
# at import time and, per __init__, reads the API-key file and constructs an
# OpenAI client as a side effect — importing this module fails if the key
# file is missing.
verifier = MedicalClaimVerifier()

def compute_score(data_source, solution_str, ground_truth, extra_info=None):  
    """Reward function: parse the model's JSON answer and score the three
    health-literacy rewrites against the gold/full-text sub-claims.

    Args:
        data_source: unused here (kept for the reward-function interface).
        solution_str: raw model output, optionally wrapped in ``` fences.
        ground_truth: dict with 'summary_subclaims' / 'fulltext_subclaims'.
        extra_info: unused.

    Returns:
        float reward: -5.0 for unparseable output; 0.0 when reference
        sub-claims are absent; otherwise a shaped reward combining
        per-level (score - threshold) margins, a -1.0 penalty per missing
        field, and a -2.0 coverage-hierarchy penalty.
    """
    gold_subs = ground_truth.get('summary_subclaims', [])
    full_subs = ground_truth.get('fulltext_subclaims', [])
    
    # No reference claims -> nothing to grade.
    if not gold_subs or not full_subs:
        return 0.0

    # 1. Parsing with fallback: strip optional markdown code fences.
    try:
        cleaned_str = solution_str.strip()
        if "```json" in cleaned_str:
            cleaned_str = cleaned_str.split("```json")[1].split("```")[0].strip()
        elif "```" in cleaned_str:
            cleaned_str = cleaned_str.split("```")[1].split("```")[0].strip()
        data = json.loads(cleaned_str)
    except Exception:
        return -5.0 

    # Valid JSON that is not an object (e.g. a bare list or number) would
    # crash the .get() calls below — treat it as a parse failure too.
    if not isinstance(data, dict):
        return -5.0

    levels = ["low", "intermediate", "proficient"]
    scores = {}
    
    # 2. Score Calculation
    for lvl in levels:
        gen_text = data.get(f"{lvl}_health_literacy", "")
        # Non-string values (null, numbers, nested objects) count as missing.
        if not gen_text or not isinstance(gen_text, str):
            scores[lvl] = {"comp": 0.0, "cov": 0.0, "missing": True}
        else:
            comp, cov = verifier.evaluate_level(gen_text, gold_subs, full_subs)
            scores[lvl] = {"comp": comp, "cov": cov, "missing": False}

    # 3. Reward Shaping Logic
    total_reward = 0.0
    
    low_cov = scores["low"]["cov"]
    int_cov = scores["intermediate"]["cov"]
    pro_cov = scores["proficient"]["cov"]

    # Soft hierarchy check: coverage should be non-decreasing with literacy
    # level. Instead of a hard -2.0 exit, subtract if the order is wrong.
    hierarchy_penalty = 0.0
    if not (low_cov <= int_cov <= pro_cov):
        hierarchy_penalty = -2.0 

    for lvl in levels:
        if scores[lvl]["missing"]:
            total_reward -= 1.0  # Penalty per missing field
            continue

        comp_s = scores[lvl]["comp"]
        cov_s = scores[lvl]["cov"]
        thresh = verifier.thresholds[lvl]

        # Continuous reward: (actual - threshold) tells the model
        # "you're 10% away" rather than a binary "you failed".
        total_reward += (comp_s - thresh["comp"]) 
        total_reward += (cov_s - thresh["cov"])

    return total_reward + hierarchy_penalty