# Source: readctrl/code/classifier/few_shot_testing_3shots_all_comb.py
# Uploaded by shahidul034 ("Add files using upload-large-folder tool"), commit c7a6fe6 (verified).
import json
import requests
import os
import numpy as np
from itertools import combinations, product
# --- Configuration ---
# Every setting keeps its original hard-coded default. Each can be overridden
# through an environment variable of the same name (K via FEW_SHOT_K), so the
# script can target other data files or servers without being edited.
DEV_SET_PATH = os.environ.get(
    "DEV_SET_PATH",
    "/home/mshahidul/readctrl/data/new_exp/test_health_literacy_data.json",
)
FEW_SHOT_POOL_PATH = os.environ.get(
    "FEW_SHOT_POOL_PATH",
    "/home/mshahidul/readctrl/data/new_exp/final_prompt_template_info.json",
)
# Chat-completions endpoint that is queried, and the model name sent with each request.
LOCAL_API_URL = os.environ.get(
    "LOCAL_API_URL", "http://172.16.34.29:8004/v1/chat/completions"
)
LOCAL_MODEL_NAME = os.environ.get(
    "LOCAL_MODEL_NAME", "Qwen/Qwen3-30B-A3B-Instruct-2507"
)
# K-shot per label: number of examples drawn from each label's pool per prompt.
K = int(os.environ.get("FEW_SHOT_K", "3"))
# --- Logic ---
def build_fixed_prompt(selected_instances):
    """Build a few-shot prompt *template* from a fixed list of instances.

    Each instance is rendered, in the given order, as a worked example made of
    its ``gen_text``, ``reasoning`` and ``label`` fields, followed by a dashed
    separator line.

    The result is a ``str.format`` template with exactly one placeholder,
    ``{input_text}``, which ``get_prediction`` fills in later. Literal ``{`` and
    ``}`` characters in the example fields are doubled. Without that, the later
    ``.format()`` call would raise ``KeyError``/``ValueError`` or substitute
    unintended fields instead of reproducing the example text verbatim.

    Args:
        selected_instances: Iterable of dicts with ``gen_text``,
            ``reasoning`` and ``label`` keys.

    Returns:
        The prompt template string.
    """
    def _escape(value):
        # Double braces so str.format() emits them as literal characters.
        return str(value).replace("{", "{{").replace("}", "}}")

    instruction = (
        "You are an expert in health communication. Your task is to judge the health literacy level of the provided text.\n"
        "Classify the text into: low_health_literacy, intermediate_health_literacy, or proficient_health_literacy.\n\n"
        "### Examples:\n"
    )
    separator = "-" * 30 + "\n"
    few_shot_blocks = "".join(
        f"Target Text: \"{_escape(ex['gen_text'])}\"\n"
        f"Reasoning: {_escape(ex['reasoning'])}\n"
        f"Label: {_escape(ex['label'])}\n"
        + separator
        for ex in selected_instances
    )
    # "{input_text}" is deliberately NOT an f-string: it is the placeholder
    # filled by str.format() at prediction time.
    return instruction + few_shot_blocks + "\n### Task:\nTarget Text: \"{input_text}\"\nReasoning:"
def get_prediction(prompt_template, input_text):
    """Send one classification request to the local model and return its reply.

    ``prompt_template`` must contain an ``{input_text}`` placeholder. It is
    filled via ``str.format`` and sent as a single user message, with
    temperature 0, to ``LOCAL_API_URL`` using ``LOCAL_MODEL_NAME``.

    Args:
        prompt_template: Template produced by ``build_fixed_prompt``.
        input_text: Text to classify.

    Returns:
        The stripped ``content`` of the first choice's message. Returns the
        sentinel string ``"Error"`` if the request fails or times out, the
        server answers with an HTTP error status, or the body does not have
        the expected ``choices[0].message.content`` shape.
    """
    final_prompt = prompt_template.format(input_text=input_text)
    payload = {
        "model": LOCAL_MODEL_NAME,
        "messages": [{"role": "user", "content": final_prompt}],
        "temperature": 0,
    }
    try:
        response = requests.post(LOCAL_API_URL, json=payload, timeout=20)
        # Treat 4xx/5xx responses as failures instead of parsing an error body.
        response.raise_for_status()
        return response.json()['choices'][0]['message']['content'].strip()
    except (requests.RequestException, ValueError, KeyError, IndexError,
            TypeError, AttributeError):
        # Narrow catch on purpose: a bare ``except`` also swallows
        # KeyboardInterrupt, so Ctrl-C could not stop the exhaustive search.
        return "Error"
def parse_label(text):
    """Map free-form model output (or a gold label) to a canonical label.

    Both the model's reply, which contains its reasoning followed by a
    ``Label:`` line, and the gold label strings from the dataset go through
    this function.

    Matching is on whole words. A plain substring test would let "low" match
    inside reasoning words such as "follow", "below", "allow" or "flow". When
    a ``Label:`` marker is present, only the text after the FIRST marker is
    searched first. This ignores the reasoning and any further examples the
    model might hallucinate after its answer. If nothing matches there, the
    whole text is searched. The earliest keyword found wins.

    Args:
        text: Model output or label string.

    Returns:
        One of ``"low_health_literacy"``, ``"intermediate_health_literacy"``,
        ``"proficient_health_literacy"``, or ``"unknown"`` if no keyword is found.
    """
    keyword_to_label = {
        "low": "low_health_literacy",
        "intermediate": "intermediate_health_literacy",
        "proficient": "proficient_health_literacy",
    }
    lowered = text.lower()
    marker = "label:"
    segments = []
    marker_pos = lowered.find(marker)
    if marker_pos != -1:
        segments.append(lowered[marker_pos + len(marker):])
    segments.append(lowered)
    for segment in segments:
        # Turn every non-letter (including "_") into a separator, so
        # "low_health_literacy" yields the word "low" but "follow" does not.
        words = "".join(ch if ch.isalpha() else " " for ch in segment).split()
        for word in words:
            if word in keyword_to_label:
                return keyword_to_label[word]
    return "unknown"
# --- Execution ---
# Load the evaluation (dev) set and the pool of candidate few-shot examples.
# Both files are decoded as UTF-8 explicitly, so non-ASCII characters in the
# health texts are read the same way whatever the platform's default locale.
with open(DEV_SET_PATH, 'r', encoding='utf-8') as f:
    dev_set = json.load(f)
with open(FEW_SHOT_POOL_PATH, 'r', encoding='utf-8') as f:
    few_shot_pool = json.load(f)
# Bucket the few-shot pool by its 'label' field (label -> list of entries).
# Entries keep their original pool order within each bucket.
categorized = {}
for pool_entry in few_shot_pool:
    entry_label = pool_entry['label']
    if entry_label not in categorized:
        categorized[entry_label] = []
    categorized[entry_label].append(pool_entry)
# 1. Generate all combinations of K items for EACH label
target_labels = ["low_health_literacy", "intermediate_health_literacy", "proficient_health_literacy"]
# combinations() yields nothing when a pool has fewer than K items. The
# Cartesian product below would then be empty, nothing would be evaluated, and
# the report step would crash. Fail fast with an explicit message instead.
short_labels = {
    label: len(categorized.get(label, []))
    for label in target_labels
    if len(categorized.get(label, [])) < K
}
if short_labels:
    raise ValueError(
        f"Need at least K={K} few-shot examples per label; "
        f"labels with too few examples (label: count): {short_labels}"
    )
label_combos = []
for label in target_labels:
    pool = categorized.get(label, [])
    # Get all ways to pick K instances from this label's pool
    label_combos.append(list(combinations(pool, K)))
# 2. Get the Cartesian Product (Every combination of the combinations)
all_possible_prompts_configs = list(product(*label_combos))
print(f"Total unique prompt configurations to test: {len(all_possible_prompts_configs)}")
results_log = []
# Gold labels do not depend on the prompt configuration, so normalise them
# once rather than re-parsing them for every configuration.
gold_labels = [parse_label(case['label']) for case in dev_set]
num_configs = len(all_possible_prompts_configs)
# 3. Iterate through every possible prompt configuration
for idx, config in enumerate(all_possible_prompts_configs):
    # Flatten the config (a tuple of per-label tuples) into one ordered list of shots
    flat_instances = [item for sublist in config for item in sublist]
    current_template = build_fixed_prompt(flat_instances)
    correct = 0
    # Count predictions that parse as "unknown". This includes failed API calls,
    # for which get_prediction returns "Error", so a low score caused by server
    # trouble can be told apart from genuine misclassification.
    unknown = 0
    # Run against Dev Set
    for case, gold in zip(dev_set, gold_labels):
        pred = parse_label(get_prediction(current_template, case['gen_text']))
        if pred == "unknown":
            unknown += 1
        if pred == gold:
            correct += 1
    # An empty dev set would otherwise raise ZeroDivisionError.
    accuracy = (correct / len(dev_set)) * 100 if dev_set else 0.0
    # Store data
    config_metadata = [{"doc_id": inst['doc_id'], "label": inst['label']} for inst in flat_instances]
    results_log.append({
        "config_index": idx,
        "accuracy": accuracy,
        "unknown_predictions": unknown,
        "instances": config_metadata
    })
    print(f"Config {idx+1}/{num_configs}: Accuracy = {accuracy:.2f}% (unknown/failed predictions: {unknown})")
# --- Save & Find Best ---
# Highest accuracy first. list.sort is stable (also with reverse=True), so
# tied configurations keep their original enumeration order.
results_log.sort(key=lambda x: x['accuracy'], reverse=True)
# Name the file after K so runs with different shot counts do not overwrite
# each other (K=3 gives the original "exhaustive_3shot_results.json").
output_path = f"/home/mshahidul/readctrl/data/new_exp/exhaustive_{K}shot_results.json"
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(results_log, f, indent=4)
if not results_log:
    # Indexing results_log[0] here would raise IndexError.
    print("No prompt configurations were evaluated; nothing to report.")
else:
    best = results_log[0]
    print("\n" + "="*50)
    print(f"WINNING CONFIGURATION (Acc: {best['accuracy']:.2f}%)")
    for inst in best['instances']:
        print(f"- {inst['label']}: {inst['doc_id']}")
    print("="*50)