readctrl / code /RL_model /unsloth_rl /reward_mock.py
shahidul034's picture
Add files using upload-large-folder tool
c7a6fe6 verified
import os
import json
import re
import concurrent.futures
from openai import OpenAI
class MedicalClaimVerifier:
    """Entailment judge for RL reward computation.

    Each atomic sub-claim is checked against a generated text via an OpenAI
    chat model that answers 'supported' or 'not_supported'.
    """

    def __init__(self, api_file="/home/mshahidul/api_new.json"):
        """Load the OpenAI API key and configure the judge model.

        Args:
            api_file: Path to a JSON file containing an "openai" key.
                Defaults to the original hard-coded path for backward
                compatibility; may also be overridden with the
                OPENAI_KEY_FILE environment variable.
        """
        api_file = os.environ.get("OPENAI_KEY_FILE", api_file)
        with open(api_file, "r") as f:
            api_keys = json.load(f)
        self.api_key = api_keys["openai"]
        # NOTE(review): ensure gpt-5-nano is actually available in your tier.
        self.model_name = "gpt-5-nano"
        self.client = OpenAI(api_key=self.api_key)
        # Minimum (completeness, coverage) targets per health-literacy level;
        # rewards below are shaped as (actual - threshold).
        self.thresholds = {
            "low": {"comp": 1.0, "cov": 0.3226},
            "intermediate": {"comp": 1.0, "cov": 0.4091},
            "proficient": {"comp": 1.0, "cov": 0.9347},
        }

    def get_prompt(self, context, claim):
        """Build the one-word entailment prompt for a (context, claim) pair."""
        prompt = f"""
CONTEXT:
{context}
CLAIM TO VERIFY:
{claim}
INSTRUCTION:
Does the CONTEXT above provide enough evidence to support the CLAIM?
- Answer 'supported' if the claim is explicitly stated or logically followable.
- Answer 'not_supported' if the claim is missing, contradicts the text, or requires outside info.
Output only one word: 'supported' or 'not_supported'.
"""
        return prompt

    @staticmethod
    def _classify(res):
        """Map a raw model reply to 1.0 (supported) or 0.0 (not supported).

        Tolerates 'not supported' / 'not-supported' spelling variants; the
        previous substring check ("supported" in res and "not_supported"
        not in res) misread those negatives as positive.
        """
        res = res.strip().lower()
        if re.search(r"not[\s_-]?supported", res):
            return 0.0
        return 1.0 if "supported" in res else 0.0

    def check_support_api(self, prompt):
        """Ask the judge model one entailment question; return 1.0 or 0.0.

        Best-effort: any API failure counts as 'not supported' (0.0) so a
        transient outage cannot crash reward computation.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
            )
            return self._classify(response.choices[0].message.content)
        except Exception:
            return 0.0

    def evaluate_level(self, gen_text, gold_subs, full_subs):
        """Score one literacy level's generated text.

        Args:
            gen_text: Generated summary text to judge claims against.
            gold_subs: Sub-claims from the gold summary (completeness).
            full_subs: Sub-claims from the full text (coverage).

        Returns:
            (completeness, coverage): fraction of gold_subs / full_subs
            supported by gen_text; (0.0, 0.0) when any input is empty.
        """
        if not gen_text or not gold_subs or not full_subs:
            return 0.0, 0.0
        # One batched thread pool over both claim lists to cut API latency.
        all_claims = gold_subs + full_subs
        prompts = [self.get_prompt(gen_text, claim) for claim in all_claims]
        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            results = list(executor.map(self.check_support_api, prompts))
        comp_results = results[:len(gold_subs)]
        cov_results = results[len(gold_subs):]
        comp_score = sum(comp_results) / len(gold_subs)
        cov_score = sum(cov_results) / len(full_subs)
        return comp_score, cov_score
# NOTE(review): module-level singleton — the constructor reads an API-key file
# from disk, so simply importing this module performs file I/O and fails if
# that path is absent. Consider lazy initialization if this becomes a problem.
verifier = MedicalClaimVerifier()
def compute_score(data_source, solution_str, ground_truth, extra_info=None):
    """Reward function for the RL trainer.

    Args:
        data_source: Dataset identifier (unused; kept for the trainer's API).
        solution_str: Model output; expected to be JSON (optionally wrapped in
            a ``` / ```json markdown fence) with one
            "<level>_health_literacy" text field per literacy level.
        ground_truth: Dict with 'summary_subclaims' and 'fulltext_subclaims'
            claim lists.
        extra_info: Unused; kept for the trainer's API.

    Returns:
        float reward:
          0.0   when ground truth has no claims to check;
         -5.0   when solution_str is not parseable as a JSON object;
          otherwise the sum over levels of (score - threshold) terms, minus
          1.0 per missing level field and minus 2.0 when coverage does not
          increase monotonically with literacy level.
    """
    gold_subs = ground_truth.get('summary_subclaims', [])
    full_subs = ground_truth.get('fulltext_subclaims', [])
    if not gold_subs or not full_subs:
        return 0.0
    # 1. Parsing with fallback: strip an optional markdown code fence first.
    try:
        cleaned_str = solution_str.strip()
        if "```json" in cleaned_str:
            cleaned_str = cleaned_str.split("```json")[1].split("```")[0].strip()
        elif "```" in cleaned_str:
            cleaned_str = cleaned_str.split("```")[1].split("```")[0].strip()
        data = json.loads(cleaned_str)
        # json.loads also accepts lists/scalars; calling .get() on those
        # would raise AttributeError below and crash the reward function,
        # so treat any non-object payload as a parse failure.
        if not isinstance(data, dict):
            return -5.0
    except Exception:
        return -5.0
    levels = ["low", "intermediate", "proficient"]
    scores = {}
    # 2. Score Calculation: completeness/coverage per literacy level.
    for lvl in levels:
        gen_text = data.get(f"{lvl}_health_literacy", "")
        if not gen_text:
            scores[lvl] = {"comp": 0.0, "cov": 0.0, "missing": True}
        else:
            comp, cov = verifier.evaluate_level(gen_text, gold_subs, full_subs)
            scores[lvl] = {"comp": comp, "cov": cov, "missing": False}
    # 3. Reward Shaping Logic
    total_reward = 0.0
    low_cov = scores["low"]["cov"]
    int_cov = scores["intermediate"]["cov"]
    pro_cov = scores["proficient"]["cov"]
    # Soft hierarchy check: coverage should rise with literacy level.
    # Instead of a hard -2.0 early exit, subtract when the order is wrong
    # so the remaining per-level gradient signal is preserved.
    hierarchy_penalty = 0.0
    if not (low_cov <= int_cov <= pro_cov):
        hierarchy_penalty = -2.0
    for lvl in levels:
        if scores[lvl]["missing"]:
            total_reward -= 1.0  # Penalty per missing field
            continue
        comp_s = scores[lvl]["comp"]
        cov_s = scores[lvl]["cov"]
        thresh = verifier.thresholds[lvl]
        # Continuous reward: (actual - threshold) tells the model
        # "you're 10% away" rather than a binary "you failed".
        total_reward += (comp_s - thresh["comp"])
        total_reward += (cov_s - thresh["cov"])
    return total_reward + hierarchy_penalty