import json
import requests
import random
import os
import numpy as np
# --- Configuration ---
# Held-out evaluation set: target texts with gold health-literacy labels.
DEV_SET_PATH = "/home/mshahidul/readctrl/data/new_exp/test_health_literacy_data.json"
# Pool of labeled examples (doc_id, label, gen_text, reasoning) sampled for few-shot prompts.
FEW_SHOT_POOL_PATH = "/home/mshahidul/readctrl/data/new_exp/final_prompt_template_info.json"
# OpenAI-compatible chat-completions endpoint served on the local network.
LOCAL_API_URL = "http://172.16.34.29:8004/v1/chat/completions"
LOCAL_MODEL_NAME = "Qwen/Qwen3-30B-A3B-Instruct-2507"
# EXPERIMENT SETTINGS
SHOTS_TO_EVALUATE = [3]  # examples sampled per label for each prompt
NUM_TRIALS = 10          # independent random prompt samples per shot count
# --- Logic ---
def build_random_prompt_with_tracking(few_shot_data, k_per_label):
    """Sample up to ``k_per_label`` examples per label and build a few-shot prompt.

    Args:
        few_shot_data: list of dicts, each with keys ``'doc_id'``,
            ``'label'``, ``'gen_text'``, and ``'reasoning'``.
        k_per_label: number of examples to draw per label; capped at the
            pool size for that label so an under-full pool is not an error.

    Returns:
        Tuple ``(prompt, used_instances)`` — the prompt string containing a
        literal ``"{input_text}"`` placeholder for later ``str.format``, and
        the list of ``{"doc_id": ..., "label": ...}`` dicts actually sampled.
    """
    instruction = (
        "You are an expert in health communication. Your task is to judge the health literacy level of the provided text.\n"
        "Classify the text into: low_health_literacy, intermediate_health_literacy, or proficient_health_literacy.\n\n"
    )
    # Bucket the pool by label so each class is sampled independently.
    categorized = {}
    for entry in few_shot_data:
        categorized.setdefault(entry['label'], []).append(entry)
    labels = ["low_health_literacy", "intermediate_health_literacy", "proficient_health_literacy"]
    used_instances = []  # {"doc_id": ..., "label": ...} pairs for tracking/reproducibility
    # Accumulate prompt pieces in a list and join once: repeated string
    # `+=` in a loop is quadratic in the total prompt length.
    parts = ["### Examples:\n"]
    for label in labels:
        pool = categorized.get(label, [])
        # min() caps the draw so random.sample never raises ValueError on small pools.
        selected = random.sample(pool, min(k_per_label, len(pool)))
        for ex in selected:
            used_instances.append({
                "doc_id": ex['doc_id'],
                "label": ex['label']
            })
            parts.append(f"Target Text: \"{ex['gen_text']}\"\n")
            parts.append(f"Reasoning: {ex['reasoning']}\n")
            parts.append(f"Label: {label}\n")
            parts.append("-" * 30 + "\n")
    few_shot_blocks = "".join(parts)
    prompt = instruction + few_shot_blocks + "\n### Task:\nTarget Text: \"{input_text}\"\nReasoning:"
    return prompt, used_instances
def get_prediction(prompt_template, input_text):
    """Fill the prompt template and query the local chat-completions API.

    Args:
        prompt_template: prompt string containing an ``"{input_text}"`` slot.
        input_text: target text to classify.

    Returns:
        The model's reply text (stripped), or the sentinel string
        ``"Error"`` when the request fails or the response is malformed.
    """
    final_prompt = prompt_template.format(input_text=input_text)
    payload = {
        "model": LOCAL_MODEL_NAME,
        "messages": [{"role": "user", "content": final_prompt}],
        "temperature": 0,  # deterministic decoding for evaluation
    }
    try:
        response = requests.post(LOCAL_API_URL, json=payload, timeout=30)
        response.raise_for_status()  # surface HTTP-level failures (4xx/5xx)
        return response.json()['choices'][0]['message']['content'].strip()
    except (requests.RequestException, KeyError, IndexError, TypeError, ValueError):
        # Narrowed from a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit. Any request or parse failure
        # still maps to the "Error" sentinel the caller expects.
        return "Error"
def parse_label(text):
    """Map a free-form model response to one of the canonical labels.

    Keywords are checked in fixed priority order ("low" first), and the
    first match wins; returns "unknown" when no keyword appears.
    """
    lowered = text.lower()
    keyword_to_label = (
        ("low", "low_health_literacy"),
        ("intermediate", "intermediate_health_literacy"),
        ("proficient", "proficient_health_literacy"),
    )
    for keyword, canonical in keyword_to_label:
        if keyword in lowered:
            return canonical
    return "unknown"
# --- Execution ---
# Load the held-out evaluation set and the few-shot example pool.
with open(DEV_SET_PATH, 'r') as f:
    dev_set = json.load(f)
with open(FEW_SHOT_POOL_PATH, 'r') as f:
    few_shot_pool = json.load(f)
# Drop any dev item whose doc_id also appears in the few-shot pool, so the
# model is never evaluated on an example it could be shown in the prompt.
shot_ids_in_pool = {item['doc_id'] for item in few_shot_pool}
clean_dev_set = [item for item in dev_set if item['doc_id'] not in shot_ids_in_pool]
all_exp_data = []
for k in SHOTS_TO_EVALUATE:
    print(f"\n>>> Running {k}-shot experiment ({NUM_TRIALS} trials)...")
    trial_data = []
    for t in range(NUM_TRIALS):
        # Each trial re-samples a fresh random few-shot prompt from the pool.
        current_template, used_meta = build_random_prompt_with_tracking(few_shot_pool, k)
        correct = 0
        for case in clean_dev_set:
            pred = parse_label(get_prediction(current_template, case['gen_text']))
            # Gold labels also pass through parse_label, so formatting
            # differences between prediction and gold don't matter.
            if pred == parse_label(case['label']):
                correct += 1
        acc = (correct / len(clean_dev_set)) * 100
        trial_info = {
            "trial_index": t + 1,
            "accuracy": acc,
            "used_instances": used_meta  # List of {"doc_id": ..., "label": ...}
        }
        trial_data.append(trial_info)
        print(f" Trial {t+1}: {acc:.2f}% accuracy")
    # Aggregating shots data
    accuracies = [td['accuracy'] for td in trial_data]
    best_trial = max(trial_data, key=lambda x: x['accuracy'])
    all_exp_data.append({
        "shots_per_label": k,
        "avg_accuracy": round(np.mean(accuracies), 2),
        "std_dev": round(np.std(accuracies), 2),
        "best_accuracy": best_trial['accuracy'],
        "best_instances": best_trial['used_instances'],
        "all_trials": trial_data
    })
# --- Save Detailed Results ---
output_json = "/home/mshahidul/readctrl/data/new_exp/shot_experiment_detailed_tracking.json"
with open(output_json, 'w') as f:
    json.dump(all_exp_data, f, indent=4)
# Summary table: one row per shot count, showing which sampled instances
# produced the best-performing trial.
print("\n" + "="*80)
print(f"{'Shots':<6} | {'Avg Acc':<10} | {'Best Acc':<10} | {'Best Sample Configuration (ID: Label)'}")
print("-" * 80)
for res in all_exp_data:
    config_str = ", ".join([f"{inst['doc_id']}: {inst['label']}" for inst in res['best_instances']])
    print(f"{res['shots_per_label']:<6} | {res['avg_accuracy']:<8}% | {res['best_accuracy']:<8}% | {config_str}")
print("="*80)