File size: 4,181 Bytes
c7a6fe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import json
import os
import re

import requests

# --- Configuration ---
# Evaluation samples; each entry is read for 'doc_id', 'gen_text' and 'label'.
DEV_SET_PATH = "/home/mshahidul/readctrl/data/new_exp/test_health_literacy_data.json"
# Pool of few-shot examples (the variant that includes a 'reasoning' field).
FEW_SHOT_SET_PATH = "/home/mshahidul/readctrl/data/new_exp/final_prompt_template_info.json"
# Chat-completions endpoint of the locally served model, and the model to request.
LOCAL_API_URL = "http://172.16.34.29:8004/v1/chat/completions"
LOCAL_MODEL_NAME = "Qwen/Qwen3-30B-A3B-Instruct-2507"

# Number of few-shot examples *per label* to evaluate. With three labels,
# k shots per label means up to 3*k examples in the prompt (0 -> zero-shot).
SHOTS_TO_EVALUATE = list(range(7))

# --- Core Functions ---

def build_dynamic_prompt(few_shot_data, k_per_label):
    """Build a ``str.format`` template with up to k examples per literacy label.

    The returned string contains exactly one replacement field,
    ``{input_text}``, which the caller fills in with ``.format``. Because the
    result is a format template, literal braces occurring in the few-shot
    examples are escaped (``{`` -> ``{{``) so they survive formatting
    unchanged instead of raising ``KeyError``/``ValueError``.

    Args:
        few_shot_data: Iterable of dicts with the keys ``'label'``,
            ``'gen_text'`` and ``'reasoning'``.
        k_per_label: Maximum number of examples to include for each label.
            Zero (or a negative value) yields the zero-shot prompt.

    Returns:
        The prompt template as a string.
    """
    instruction = (
        "You are an expert in health communication. Your task is to judge the health literacy level of the provided text.\n"
        "Classify the text into: low_health_literacy, intermediate_health_literacy, or proficient_health_literacy.\n\n"
    )
    task_block = "### Task:\nTarget Text: \"{input_text}\"\nReasoning:"

    # A negative k would make the slice below drop examples from the end of
    # each group, which is never what the caller means; treat it as zero-shot.
    if k_per_label <= 0:
        return instruction + task_block

    def _escape(value):
        # Double the braces so a later .format() emits them literally.
        return str(value).replace("{", "{{").replace("}", "}}")

    # Organize few-shot data by label, preserving the pool's original order.
    categorized = {}
    for entry in few_shot_data:
        categorized.setdefault(entry['label'], []).append(entry)

    labels = ["low_health_literacy", "intermediate_health_literacy", "proficient_health_literacy"]
    separator = "-" * 30 + "\n"

    parts = ["### Examples:\n"]
    for label in labels:
        for ex in categorized.get(label, [])[:k_per_label]:
            parts.append(f"Target Text: \"{_escape(ex['gen_text'])}\"\n")
            parts.append(f"Reasoning: {_escape(ex['reasoning'])}\n")
            parts.append(f"Label: {label}\n")
            parts.append(separator)

    return instruction + "".join(parts) + "\n" + task_block

def get_prediction(prompt_template, input_text):
    """Query the local chat-completions endpoint for one input text.

    The ``{input_text}`` field of ``prompt_template`` is filled in and sent
    as a single user message to ``LOCAL_API_URL`` with temperature 0 and a
    30-second timeout.

    Returns:
        The stripped message content of the first choice in the response, or
        the string ``"Error"`` if sending the request or unpacking the
        response raises any ``Exception``.
    """
    user_message = {
        "role": "user",
        "content": prompt_template.format(input_text=input_text),
    }
    request_body = {
        "model": LOCAL_MODEL_NAME,
        "messages": [user_message],
        "temperature": 0,
    }

    try:
        reply = requests.post(LOCAL_API_URL, json=request_body, timeout=30).json()
        first_choice = reply['choices'][0]
        return first_choice['message']['content'].strip()
    except Exception:
        # Best-effort: any failure is reported as a sentinel string rather
        # than aborting the evaluation run.
        return "Error"

# Maps the distinguishing keyword of each class to its canonical dataset label.
# Insertion order (low, intermediate, proficient) is the fallback priority.
_LABEL_BY_KEYWORD = {
    "low": "low_health_literacy",
    "intermediate": "intermediate_health_literacy",
    "proficient": "proficient_health_literacy",
}
_KEYWORD_ALTERNATION = "|".join(_LABEL_BY_KEYWORD)
# "label: low...", "label - **proficient", etc. (text is lower-cased first).
_EXPLICIT_LABEL_RE = re.compile(
    r"label\W{0,10}(" + _KEYWORD_ALTERNATION + r")(?![a-z])"
)
# A fully spelled-out label such as "low_health_literacy" or "low health literacy".
_FULL_LABEL_RE = re.compile(
    r"(?<![a-z])(" + _KEYWORD_ALTERNATION + r")[\s_-]*health[\s_-]*literacy"
)
# Keyword standing alone, i.e. not embedded in a longer word like "follows".
_STANDALONE_KEYWORD_RES = tuple(
    (keyword, re.compile(r"(?<![a-z])" + keyword + r"(?![a-z])"))
    for keyword in _LABEL_BY_KEYWORD
)


def parse_label(text):
    """Normalize LLM output (or a dataset label) to a canonical label.

    The model is prompted to write free-form reasoning before its label, so a
    plain substring test is unreliable: "follows" or "below" contain "low".
    The text is therefore inspected from most to least specific:

    1. an explicit ``Label: <class>`` statement (the last one wins),
    2. a fully spelled-out label name (the last one wins),
    3. a class keyword appearing as a standalone word, checked in the order
       low, intermediate, proficient.

    Args:
        text: Raw model output or a dataset label string.

    Returns:
        One of ``"low_health_literacy"``, ``"intermediate_health_literacy"``,
        ``"proficient_health_literacy"``, or ``"unknown"`` if nothing matches.
    """
    text = text.lower()

    for pattern in (_EXPLICIT_LABEL_RE, _FULL_LABEL_RE):
        found = pattern.findall(text)
        if found:
            # The final mention is the model's conclusion; earlier ones are
            # typically part of the reasoning.
            return _LABEL_BY_KEYWORD[found[-1]]

    for keyword, pattern in _STANDALONE_KEYWORD_RES:
        if pattern.search(text):
            return _LABEL_BY_KEYWORD[keyword]

    return "unknown"

# --- Main Execution ---

# 1. Load Data
# Read the JSON as UTF-8 explicitly so the result does not depend on the
# platform's default text encoding.
with open(DEV_SET_PATH, 'r', encoding='utf-8') as f:
    dev_set = json.load(f)
with open(FEW_SHOT_SET_PATH, 'r', encoding='utf-8') as f:
    few_shot_pool = json.load(f)

# 2. Filter Dev Set
# Ensure no overlap between few-shot examples and dev set
shot_ids = {item['doc_id'] for item in few_shot_pool}
clean_dev_set = [item for item in dev_set if item['doc_id'] not in shot_ids]

# Without samples the accuracy below would divide by zero; stop with a clear
# message instead of an unexplained ZeroDivisionError.
if not clean_dev_set:
    raise SystemExit(
        "No dev samples remain after removing few-shot overlaps; nothing to evaluate."
    )

results_summary = []

print(f"Starting Evaluation on {len(clean_dev_set)} samples...\n")

# 3. Loop through shot counts
for k in SHOTS_TO_EVALUATE:
    print(f"Evaluating {k}-shot per label (Total {k*3} examples)...")

    # The template is identical for every sample at a given k, so build it once.
    current_template = build_dynamic_prompt(few_shot_pool, k)
    correct = 0

    for case in clean_dev_set:
        raw_output = get_prediction(current_template, case['gen_text'])
        pred = parse_label(raw_output)
        # Gold labels go through the same normalizer so both sides compare
        # in canonical form.
        actual = parse_label(case['label'])

        if pred == actual:
            correct += 1

    accuracy = (correct / len(clean_dev_set)) * 100
    results_summary.append({"shots_per_label": k, "accuracy": accuracy})
    print(f"-> Accuracy: {accuracy:.2f}%\n")

# --- Final Report ---
print("-" * 30)
print(f"{'Shots/Label':<15} | {'Accuracy':<10}")
print("-" * 30)
for res in results_summary:
    print(f"{res['shots_per_label']:<15} | {res['accuracy']:.2f}%")
# Persist the per-k accuracies next to the input data.
with open("/home/mshahidul/readctrl/data/new_exp/few_shot_evaluation_summary.json", 'w') as f:
    json.dump(results_summary, f, indent=4)