import json
import os
import random
import re

import numpy as np
import requests
|
|
| |
# --- Data locations ---------------------------------------------------------
# JSON file with the evaluation cases; loaded below and filtered so that no
# document that also appears in the few-shot pool is scored.
DEV_SET_PATH = "/home/mshahidul/readctrl/data/new_exp/test_health_literacy_data.json"
# JSON file with the labelled examples that few-shot demonstrations are
# sampled from.
FEW_SHOT_POOL_PATH = "/home/mshahidul/readctrl/data/new_exp/final_prompt_template_info.json"

# --- Model endpoint ---------------------------------------------------------
# Chat-completions URL that every prediction request is POSTed to, and the
# model name sent in the request payload.
LOCAL_API_URL = "http://172.16.34.29:8004/v1/chat/completions"
LOCAL_MODEL_NAME = "Qwen/Qwen3-30B-A3B-Instruct-2507"

# --- Experiment grid --------------------------------------------------------
# Each entry is a "shots per label" setting to evaluate.
SHOTS_TO_EVALUATE = [3]
# Number of independently sampled prompts evaluated for each setting.
NUM_TRIALS = 10
|
|
| |
|
|
def build_random_prompt_with_tracking(few_shot_data, k_per_label):
    """Sample few-shot examples per label and build a prompt template.

    Args:
        few_shot_data: Iterable of dicts, each providing the keys ``label``,
            ``doc_id``, ``gen_text`` and ``reasoning``.
        k_per_label: Number of examples to draw for each of the three labels.
            If a label has fewer examples available, all of them are used.

    Returns:
        A ``(prompt, used_instances)`` tuple. ``prompt`` is a ``str.format``
        template whose only replacement field is ``{input_text}``.
        ``used_instances`` lists ``{"doc_id", "label"}`` dicts for the sampled
        examples, in the order they appear in the prompt.
    """
    instruction = (
        "You are an expert in health communication. Your task is to judge the health literacy level of the provided text.\n"
        "Classify the text into: low_health_literacy, intermediate_health_literacy, or proficient_health_literacy.\n\n"
    )
    labels = ["low_health_literacy", "intermediate_health_literacy", "proficient_health_literacy"]

    # Group the pool by label so each class can be sampled independently.
    categorized = {}
    for entry in few_shot_data:
        categorized.setdefault(entry["label"], []).append(entry)

    def _escape(value):
        # The prompt is later filled in with str.format(); literal braces in
        # example text would otherwise be parsed as replacement fields and
        # raise (or silently mangle the example). Doubling them makes
        # str.format() emit the original text unchanged.
        return str(value).replace("{", "{{").replace("}", "}}")

    parts = [instruction, "### Examples:\n"]
    used_instances = []
    separator = "-" * 30 + "\n"
    for label in labels:
        pool = categorized.get(label, [])
        selected = random.sample(pool, min(k_per_label, len(pool)))
        for ex in selected:
            # Record which pool entries ended up in this prompt.
            used_instances.append({"doc_id": ex["doc_id"], "label": ex["label"]})
            parts.append(f'Target Text: "{_escape(ex["gen_text"])}"\n')
            parts.append(f"Reasoning: {_escape(ex['reasoning'])}\n")
            parts.append(f"Label: {label}\n")
            parts.append(separator)

    # The trailing task block carries the single real placeholder.
    parts.append("\n### Task:\nTarget Text: \"{input_text}\"\nReasoning:")
    return "".join(parts), used_instances
|
|
def get_prediction(prompt_template, input_text):
    """Query the local chat-completions endpoint for one input text.

    Args:
        prompt_template: ``str.format`` template containing an
            ``{input_text}`` replacement field.
        input_text: Text substituted into the template.

    Returns:
        The stripped content of the first returned choice, or the sentinel
        string ``"Error"`` when the request fails or the response does not
        have the expected structure.
    """
    final_prompt = prompt_template.format(input_text=input_text)
    payload = {
        "model": LOCAL_MODEL_NAME,
        "messages": [{"role": "user", "content": final_prompt}],
        "temperature": 0,
    }
    try:
        response = requests.post(LOCAL_API_URL, json=payload, timeout=30)
        # Surface HTTP error statuses as exceptions instead of relying on the
        # error body happening to lack a "choices" key.
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"].strip()
    except (
        requests.RequestException,  # connection errors, timeouts, HTTP errors
        ValueError,                 # response body is not valid JSON
        KeyError,                   # expected key missing from the payload
        IndexError,                 # empty "choices" list
        TypeError,                  # payload has an unexpected shape
        AttributeError,             # content is not a string (e.g. null)
    ):
        # Best-effort: one failed call must not abort the run. Unlike a bare
        # ``except:``, this lets KeyboardInterrupt/SystemExit propagate so a
        # long experiment can still be stopped.
        return "Error"
|
|
# Keyword -> canonical label, in the precedence order used when several
# keywords are present. Each keyword must not be embedded in a longer
# alphabetic word, so "follow", "allow" or "slow" no longer count as "low";
# an underscore after the keyword (as in "low_health_literacy") is fine.
_LABEL_PATTERNS = tuple(
    (re.compile(rf"(?<![a-z]){keyword}(?![a-z])"), f"{keyword}_health_literacy")
    for keyword in ("low", "intermediate", "proficient")
)


def _match_label(segment):
    """Return the first canonical label whose keyword occurs in *segment*, else None."""
    for pattern, label in _LABEL_PATTERNS:
        if pattern.search(segment):
            return label
    return None


def parse_label(text):
    """Map free text (a model reply or a gold label) to a canonical label.

    The prompt asks the model for reasoning before the label, so the reply
    may contain arbitrary prose. The text after the last ``Label:`` marker is
    inspected first; if there is no marker, or nothing recognisable follows
    it, the whole text is searched.

    Args:
        text: String to classify; matching is case-insensitive.

    Returns:
        ``"low_health_literacy"``, ``"intermediate_health_literacy"``,
        ``"proficient_health_literacy"``, or ``"unknown"`` when no label
        keyword is found.
    """
    text = text.lower()
    marker = "label:"
    marker_pos = text.rfind(marker)
    if marker_pos != -1:
        found = _match_label(text[marker_pos + len(marker):])
        if found is not None:
            return found
    return _match_label(text) or "unknown"
|
|
| |
|
|
# ---------------------------------------------------------------------------
# Experiment driver: for every shot count, sample NUM_TRIALS random few-shot
# prompts, score each on the dev set, then save and print a summary.
# ---------------------------------------------------------------------------

# Load the evaluation cases and the few-shot pool (JSON, read as UTF-8).
with open(DEV_SET_PATH, 'r', encoding="utf-8") as f:
    dev_set = json.load(f)
with open(FEW_SHOT_POOL_PATH, 'r', encoding="utf-8") as f:
    few_shot_pool = json.load(f)

# Drop dev cases whose document also appears in the few-shot pool, so a
# document used as a demonstration is never scored.
shot_ids_in_pool = {item['doc_id'] for item in few_shot_pool}
clean_dev_set = [item for item in dev_set if item['doc_id'] not in shot_ids_in_pool]

# Fail early with a clear message instead of a ZeroDivisionError when the
# accuracy is computed below.
if not clean_dev_set:
    raise ValueError(
        "No evaluation cases remain after removing few-shot pool documents from the dev set."
    )

# Gold labels do not change between trials; normalise them once.
gold_labels = [parse_label(case['label']) for case in clean_dev_set]

all_exp_data = []

for k in SHOTS_TO_EVALUATE:
    print(f"\n>>> Running {k}-shot experiment ({NUM_TRIALS} trials)...")
    trial_data = []

    for t in range(NUM_TRIALS):
        # Each trial uses a freshly sampled set of demonstrations.
        current_template, used_meta = build_random_prompt_with_tracking(few_shot_pool, k)

        correct = 0
        for case, gold in zip(clean_dev_set, gold_labels):
            pred = parse_label(get_prediction(current_template, case['gen_text']))
            if pred == gold:
                correct += 1

        acc = (correct / len(clean_dev_set)) * 100

        trial_data.append({
            "trial_index": t + 1,
            "accuracy": acc,
            "used_instances": used_meta,
        })
        print(f"  Trial {t+1}: {acc:.2f}% accuracy")

    # Summarise the trials for this shot count.
    accuracies = [td['accuracy'] for td in trial_data]
    best_trial = max(trial_data, key=lambda x: x['accuracy'])

    all_exp_data.append({
        "shots_per_label": k,
        # float() keeps plain Python floats in the data that is dumped to JSON.
        "avg_accuracy": round(float(np.mean(accuracies)), 2),
        "std_dev": round(float(np.std(accuracies)), 2),
        "best_accuracy": best_trial['accuracy'],
        "best_instances": best_trial['used_instances'],
        "all_trials": trial_data,
    })

# Persist the full per-trial record, including which examples each prompt used.
output_json = "/home/mshahidul/readctrl/data/new_exp/shot_experiment_detailed_tracking.json"
with open(output_json, 'w', encoding="utf-8") as f:
    json.dump(all_exp_data, f, indent=4)

# Console summary: one row per shot count with its best-performing sample.
print("\n" + "="*80)
print(f"{'Shots':<6} | {'Avg Acc':<10} | {'Best Acc':<10} | {'Best Sample Configuration (ID: Label)'}")
print("-" * 80)
for res in all_exp_data:
    config_str = ", ".join([f"{inst['doc_id']}: {inst['label']}" for inst in res['best_instances']])
    print(f"{res['shots_per_label']:<6} | {res['avg_accuracy']:<8}% | {res['best_accuracy']:<8}% | {config_str}")
print("="*80)