File size: 4,219 Bytes
c7a6fe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import json
import requests
import os
import numpy as np
from itertools import combinations, product

# --- Configuration ---
# Held-out evaluation set; each case provides 'gen_text' and a gold 'label'.
DEV_SET_PATH = "/home/mshahidul/readctrl/data/new_exp/test_health_literacy_data.json"
# Pool of candidate few-shot examples ('gen_text', 'reasoning', 'label', 'doc_id').
FEW_SHOT_POOL_PATH = "/home/mshahidul/readctrl/data/new_exp/final_prompt_template_info.json"
# OpenAI-compatible chat-completions endpoint serving the local model below.
LOCAL_API_URL = "http://172.16.34.29:8004/v1/chat/completions"
LOCAL_MODEL_NAME = "Qwen/Qwen3-30B-A3B-Instruct-2507"

# K-shot per label: number of few-shot examples drawn from each label's pool.
K = 3 

# --- Logic ---

def build_fixed_prompt(selected_instances):
    """Build a few-shot classification prompt from a fixed list of examples.

    Parameters
    ----------
    selected_instances : list[dict]
        Example records, each with 'gen_text', 'reasoning', and 'label' keys.

    Returns
    -------
    str
        A prompt template containing exactly one ``{input_text}`` placeholder,
        to be filled later via ``template.format(input_text=...)``.
    """
    instruction = (
        "You are an expert in health communication. Your task is to judge the health literacy level of the provided text.\n"
        "Classify the text into: low_health_literacy, intermediate_health_literacy, or proficient_health_literacy.\n\n"
        "### Examples:\n"
    )

    def _escape_braces(text):
        # BUG FIX: example content is embedded in a template that is later run
        # through str.format(); a literal '{' or '}' in an example would raise
        # KeyError/ValueError there. Doubling the braces makes them literal.
        return str(text).replace("{", "{{").replace("}", "}}")

    # Build with join instead of repeated string concatenation.
    parts = []
    for ex in selected_instances:
        parts.append(f"Target Text: \"{_escape_braces(ex['gen_text'])}\"\n")
        parts.append(f"Reasoning: {_escape_braces(ex['reasoning'])}\n")
        parts.append(f"Label: {_escape_braces(ex['label'])}\n")
        parts.append("-" * 30 + "\n")
    few_shot_blocks = "".join(parts)

    return instruction + few_shot_blocks + "\n### Task:\nTarget Text: \"{input_text}\"\nReasoning:"

def get_prediction(prompt_template, input_text):
    """Query the local chat-completions API with the filled-in prompt.

    Parameters
    ----------
    prompt_template : str
        Template containing an ``{input_text}`` placeholder.
    input_text : str
        The text to classify.

    Returns
    -------
    str
        The model's raw response text, or the sentinel "Error" on any
        request/response failure (best-effort: callers score "Error" as a
        wrong prediction rather than crashing the sweep).
    """
    final_prompt = prompt_template.format(input_text=input_text)
    payload = {
        "model": LOCAL_MODEL_NAME,
        "messages": [{"role": "user", "content": final_prompt}],
        "temperature": 0,  # deterministic decoding so configs are comparable
    }
    try:
        response = requests.post(LOCAL_API_URL, json=payload, timeout=20)
        response.raise_for_status()  # treat HTTP error statuses as failures
        return response.json()['choices'][0]['message']['content'].strip()
    except (requests.RequestException, KeyError, IndexError, ValueError):
        # Narrowed from a bare `except`: covers connection/timeout/HTTP
        # errors, JSON decode failures (JSONDecodeError is a ValueError),
        # and unexpected response shapes — without swallowing
        # KeyboardInterrupt or genuine programming errors.
        return "Error"

def parse_label(text):
    """Map a free-form model response onto one of the canonical labels.

    Matching is case-insensitive substring search, checked in a fixed
    priority order (low, then intermediate, then proficient); anything
    that matches none of the markers yields "unknown".
    """
    lowered = text.lower()
    marker_to_label = (
        ("low", "low_health_literacy"),
        ("intermediate", "intermediate_health_literacy"),
        ("proficient", "proficient_health_literacy"),
    )
    for marker, canonical in marker_to_label:
        if marker in lowered:
            return canonical
    return "unknown"

# --- Execution ---
# End-to-end exhaustive search: build every possible K-shot-per-label prompt
# configuration, score each against the dev set via the local model, then
# save the ranked results and print the winner.

with open(DEV_SET_PATH, 'r') as f:
    dev_set = json.load(f)
with open(FEW_SHOT_POOL_PATH, 'r') as f:
    few_shot_pool = json.load(f)

# Group pool by labels
categorized = {}
for entry in few_shot_pool:
    categorized.setdefault(entry['label'], []).append(entry)

# 1. Generate all combinations of K items for EACH label
label_combos = []
target_labels = ["low_health_literacy", "intermediate_health_literacy", "proficient_health_literacy"]

for label in target_labels:
    pool = categorized.get(label, [])
    # Get all ways to pick K instances from this label's pool.
    # NOTE(review): if any label's pool has fewer than K entries,
    # combinations() yields nothing and the Cartesian product below
    # collapses to zero configurations — confirm each pool has >= K items.
    label_combos.append(list(combinations(pool, K)))

# 2. Get the Cartesian Product (Every combination of the combinations)
all_possible_prompts_configs = list(product(*label_combos))

print(f"Total unique prompt configurations to test: {len(all_possible_prompts_configs)}")

results_log = []

# 3. Iterate through every possible prompt configuration
for idx, config in enumerate(all_possible_prompts_configs):
    # Flatten the config (it's a tuple of tuples: one K-tuple per label)
    flat_instances = [item for sublist in config for item in sublist]
    
    current_template = build_fixed_prompt(flat_instances)
    correct = 0
    
    # Run against Dev Set — one model call per dev case, so total API calls
    # scale as len(all_possible_prompts_configs) * len(dev_set).
    for case in dev_set:
        pred = parse_label(get_prediction(current_template, case['gen_text']))
        # Gold labels pass through the same parser so both sides are normalized.
        if pred == parse_label(case['label']):
            correct += 1
            
    accuracy = (correct / len(dev_set)) * 100  # NOTE(review): ZeroDivisionError if dev_set is empty
    
    # Store data
    config_metadata = [{"doc_id": inst['doc_id'], "label": inst['label']} for inst in flat_instances]
    results_log.append({
        "config_index": idx,
        "accuracy": accuracy,
        "instances": config_metadata
    })
    
    print(f"Config {idx+1}/{len(all_possible_prompts_configs)}: Accuracy = {accuracy:.2f}%")

# --- Save & Find Best ---
# Sort best-first so results_log[0] is the winning configuration.
results_log.sort(key=lambda x: x['accuracy'], reverse=True)

output_path = "/home/mshahidul/readctrl/data/new_exp/exhaustive_3shot_results.json"
with open(output_path, 'w') as f:
    json.dump(results_log, f, indent=4)

best = results_log[0]  # NOTE(review): IndexError when no configurations were generated
print("\n" + "="*50)
print(f"WINNING CONFIGURATION (Acc: {best['accuracy']:.2f}%)")
for inst in best['instances']:
    print(f"- {inst['label']}: {inst['doc_id']}")
print("="*50)