File size: 4,833 Bytes
c7a6fe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import json
import os
import random
import re

import numpy as np
import requests

# --- Configuration ---
# Each data path / endpoint below can be overridden by an environment variable
# of the same name; the defaults reproduce the original experiment setup.
DEV_SET_PATH = os.environ.get(
    "DEV_SET_PATH",
    "/home/mshahidul/readctrl/data/new_exp/test_health_literacy_data.json",
)
FEW_SHOT_POOL_PATH = os.environ.get(
    "FEW_SHOT_POOL_PATH",
    "/home/mshahidul/readctrl/data/new_exp/final_prompt_template_info.json",
)
LOCAL_API_URL = os.environ.get(
    "LOCAL_API_URL",
    "http://172.16.34.29:8004/v1/chat/completions",
)
LOCAL_MODEL_NAME = os.environ.get(
    "LOCAL_MODEL_NAME",
    "Qwen/Qwen3-30B-A3B-Instruct-2507",
)

# EXPERIMENT SETTINGS
SHOTS_TO_EVALUATE = [3]  # number of examples sampled per label in each prompt
NUM_TRIALS = 10  # independent random draws of the few-shot examples per setting

# --- Logic ---

def _escape_format_braces(text):
    """Double literal braces so *text* survives a later ``str.format`` call."""
    return str(text).replace("{", "{{").replace("}", "}}")


def build_random_prompt_with_tracking(few_shot_data, k_per_label):
    """Sample few-shot examples per label and build a classification prompt template.

    Args:
        few_shot_data: Iterable of dicts with keys ``doc_id``, ``label``,
            ``gen_text`` and ``reasoning``. Entries whose label is not one of
            the three health-literacy labels are ignored.
        k_per_label: Maximum number of examples drawn (without replacement,
            via ``random.sample``) for each label; a label with fewer examples
            contributes all of them.

    Returns:
        ``(prompt, used_instances)``. ``prompt`` is a ``str.format`` template
        whose only placeholder is ``{input_text}``. Braces coming from the
        example texts are escaped, so formatting never fails on them.
        ``used_instances`` lists ``{"doc_id": ..., "label": ...}`` dicts in the
        order the examples appear in the prompt.
    """
    instruction = (
        "You are an expert in health communication. Your task is to judge the health literacy level of the provided text.\n"
        "Classify the text into: low_health_literacy, intermediate_health_literacy, or proficient_health_literacy.\n\n"
    )

    categorized = {}
    for entry in few_shot_data:
        categorized.setdefault(entry['label'], []).append(entry)

    labels = ["low_health_literacy", "intermediate_health_literacy", "proficient_health_literacy"]

    parts = [instruction, "### Examples:\n"]
    used_instances = []  # (doc_id, label) of every example shown, in prompt order
    for label in labels:
        pool = categorized.get(label, [])
        selected = random.sample(pool, min(k_per_label, len(pool)))

        for ex in selected:
            used_instances.append({
                "doc_id": ex['doc_id'],
                "label": ex['label'],
            })
            # Example text is data, not template syntax: escape its braces so
            # the caller's prompt_template.format(input_text=...) cannot choke
            # on them.
            parts.append(f"Target Text: \"{_escape_format_braces(ex['gen_text'])}\"\n")
            parts.append(f"Reasoning: {_escape_format_braces(ex['reasoning'])}\n")
            parts.append(f"Label: {label}\n")
            parts.append("-" * 30 + "\n")

    # Literal (non-f-string) tail: {input_text} stays a format placeholder.
    parts.append("\n### Task:\nTarget Text: \"{input_text}\"\nReasoning:")
    return "".join(parts), used_instances

def get_prediction(prompt_template, input_text, timeout=30):
    """Send one filled-in prompt to the local chat-completions endpoint.

    Args:
        prompt_template: Template with a single ``{input_text}`` placeholder.
            It is rendered with ``str.format``, so any other literal braces
            must be escaped as ``{{`` / ``}}``.
        input_text: Text substituted for ``{input_text}``.
        timeout: Seconds to wait for the HTTP response (default 30).

    Returns:
        The stripped ``content`` of the first choice's message, or the string
        ``"Error"`` when the request fails or the response does not have the
        expected ``choices[0].message.content`` shape.
    """
    final_prompt = prompt_template.format(input_text=input_text)
    payload = {
        "model": LOCAL_MODEL_NAME,
        "messages": [{"role": "user", "content": final_prompt}],
        # temperature 0 keeps decoding (near-)deterministic, so trial-to-trial
        # variation comes mainly from the sampled few-shot examples.
        "temperature": 0,
    }
    try:
        response = requests.post(LOCAL_API_URL, json=payload, timeout=timeout)
        response.raise_for_status()
        return response.json()['choices'][0]['message']['content'].strip()
    except (requests.RequestException, ValueError, KeyError, IndexError, TypeError, AttributeError) as exc:
        # Best-effort: the caller scores "Error" as a wrong answer. Report it
        # so a dead server is not mistaken for a low accuracy. Unlike a bare
        # except, this still lets KeyboardInterrupt stop the run.
        print(f"   [warn] prediction request failed: {exc!r}")
        return "Error"

_LABEL_BY_KEYWORD = {
    "low": "low_health_literacy",
    "intermediate": "intermediate_health_literacy",
    "proficient": "proficient_health_literacy",
}
# A keyword only counts as a standalone word, so "follow", "below" or "slowly"
# do not count as "low". A trailing underscore is allowed so the full label
# names ("low_health_literacy", ...) still match.
_LABEL_KEYWORD_RE = re.compile(r"(?<![a-z])(low|intermediate|proficient)(?![a-z])")


def parse_label(text):
    """Map free-form model output (or a gold label string) to a canonical label.

    Resolution order:
      1. If the text contains ``"label:"`` (case-insensitive), use the first
         keyword after the LAST such marker. This matches the
         "Reasoning: ... Label: X" format shown in the few-shot prompt.
      2. Otherwise, use the last whole-word keyword anywhere in the text.

    Args:
        text: Model response or label string; may be empty.

    Returns:
        One of ``low_health_literacy``, ``intermediate_health_literacy``,
        ``proficient_health_literacy``, or ``"unknown"`` if no keyword is found.
    """
    if not text:
        return "unknown"
    lowered = text.lower()

    marker = lowered.rfind("label:")
    if marker != -1:
        match = _LABEL_KEYWORD_RE.search(lowered, marker + len("label:"))
        if match:
            return _LABEL_BY_KEYWORD[match.group(1)]

    # No usable marker: the conclusion usually comes last in reasoning text.
    keywords = _LABEL_KEYWORD_RE.findall(lowered)
    if keywords:
        return _LABEL_BY_KEYWORD[keywords[-1]]
    return "unknown"

# --- Execution ---

with open(DEV_SET_PATH, 'r', encoding='utf-8') as f:
    dev_set = json.load(f)
with open(FEW_SHOT_POOL_PATH, 'r', encoding='utf-8') as f:
    few_shot_pool = json.load(f)

# Remove dev items whose doc_id also occurs in the few-shot pool, so a test
# document can never appear among the in-context examples.
shot_ids_in_pool = {item['doc_id'] for item in few_shot_pool}
clean_dev_set = [item for item in dev_set if item['doc_id'] not in shot_ids_in_pool]
if not clean_dev_set:
    # Accuracy below divides by len(clean_dev_set); fail with a clear message.
    raise SystemExit("No dev-set items left after removing few-shot pool documents; nothing to evaluate.")

# Gold labels are identical in every trial, so normalise them once up front.
gold_labels = [parse_label(case['label']) for case in clean_dev_set]

all_exp_data = []

for k in SHOTS_TO_EVALUATE:
    print(f"\n>>> Running {k}-shot experiment ({NUM_TRIALS} trials)...")
    trial_data = []

    for t in range(NUM_TRIALS):
        # Each trial draws a fresh random set of k examples per label.
        current_template, used_meta = build_random_prompt_with_tracking(few_shot_pool, k)
        correct = 0
        unparsed = 0

        for case, gold in zip(clean_dev_set, gold_labels):
            pred = parse_label(get_prediction(current_template, case['gen_text']))
            if pred == "unknown":
                # Failed request or unparseable answer: always scored as wrong,
                # even when the gold label itself could not be parsed.
                unparsed += 1
            elif pred == gold:
                correct += 1

        acc = (correct / len(clean_dev_set)) * 100

        trial_data.append({
            "trial_index": t + 1,
            "accuracy": acc,
            "used_instances": used_meta,  # list of {"doc_id": ..., "label": ...}
            "unparsed_predictions": unparsed,
        })
        print(f"   Trial {t + 1}: {acc:.2f}% accuracy ({unparsed} unparsed)")

    if not trial_data:
        # NUM_TRIALS < 1: nothing to aggregate, and max() would raise.
        continue

    # Aggregate across trials for this shot count.
    accuracies = [td['accuracy'] for td in trial_data]
    best_trial = max(trial_data, key=lambda x: x['accuracy'])

    all_exp_data.append({
        "shots_per_label": k,
        "avg_accuracy": round(np.mean(accuracies), 2),
        # Population standard deviation (np.std default ddof=0) across trials.
        "std_dev": round(np.std(accuracies), 2),
        "best_accuracy": best_trial['accuracy'],
        "best_instances": best_trial['used_instances'],
        "all_trials": trial_data,
    })

# --- Save Detailed Results ---
output_json = "/home/mshahidul/readctrl/data/new_exp/shot_experiment_detailed_tracking.json"
with open(output_json, 'w', encoding='utf-8') as f:
    json.dump(all_exp_data, f, indent=4)

print("\n" + "=" * 80)
print(f"{'Shots':<6} | {'Avg Acc':<10} | {'Best Acc':<10} | {'Best Sample Configuration (ID: Label)'}")
print("-" * 80)
for res in all_exp_data:
    config_str = ", ".join(f"{inst['doc_id']}: {inst['label']}" for inst in res['best_instances'])
    print(f"{res['shots_per_label']:<6} | {res['avg_accuracy']:<8.2f}% | {res['best_accuracy']:<8.2f}% | {config_str}")
print("=" * 80)