import json
import os
import re
from itertools import combinations, product

import numpy as np
import requests

# Evaluation cases; every candidate prompt is scored against all of them.
DEV_SET_PATH = "/home/mshahidul/readctrl/data/new_exp/test_health_literacy_data.json"
# Labelled examples from which the few-shot demonstrations are drawn.
FEW_SHOT_POOL_PATH = "/home/mshahidul/readctrl/data/new_exp/final_prompt_template_info.json"
# Chat-completions endpoint and model name sent with every request.
LOCAL_API_URL = "http://172.16.34.29:8004/v1/chat/completions"
LOCAL_MODEL_NAME = "Qwen/Qwen3-30B-A3B-Instruct-2507"

# Number of few-shot examples taken per label in each prompt configuration.
K = 3

def build_fixed_prompt(selected_instances):
    """Build a prompt template from a specific list of few-shot instances.

    Args:
        selected_instances: Iterable of dicts, each providing the keys
            ``gen_text``, ``reasoning`` and ``label``.

    Returns:
        A ``str.format`` template whose only replacement field is
        ``{input_text}``, the slot for the text to classify.
    """
    instruction = (
        "You are an expert in health communication. Your task is to judge the health literacy level of the provided text.\n"
        "Classify the text into: low_health_literacy, intermediate_health_literacy, or proficient_health_literacy.\n\n"
        "### Examples:\n"
    )

    def _escape(value):
        # The returned string is later run through str.format, so literal
        # braces coming from example content must be doubled; otherwise they
        # would be parsed as replacement fields and raise.
        return str(value).replace("{", "{{").replace("}", "}}")

    few_shot_blocks = "".join(
        f"Target Text: \"{_escape(ex['gen_text'])}\"\n"
        f"Reasoning: {_escape(ex['reasoning'])}\n"
        f"Label: {_escape(ex['label'])}\n"
        + "-" * 30 + "\n"
        for ex in selected_instances
    )

    return instruction + few_shot_blocks + "\n### Task:\nTarget Text: \"{input_text}\"\nReasoning:"

def get_prediction(prompt_template, input_text):
    """Send one classification request to the local model and return its reply.

    Args:
        prompt_template: ``str.format`` template containing an
            ``{input_text}`` replacement field.
        input_text: Text to classify; substituted into the template.

    Returns:
        The stripped message content of the first choice, or the sentinel
        string ``"Error"`` when the request fails or the reply does not have
        the expected shape.
    """
    final_prompt = prompt_template.format(input_text=input_text)
    payload = {
        "model": LOCAL_MODEL_NAME,
        "messages": [{"role": "user", "content": final_prompt}],
        "temperature": 0,
    }
    try:
        response = requests.post(LOCAL_API_URL, json=payload, timeout=20)
        response.raise_for_status()
        return response.json()['choices'][0]['message']['content'].strip()
    except (
        requests.RequestException,  # connection errors, timeouts, HTTP errors
        ValueError,                 # body is not valid JSON
        KeyError,                   # expected key missing from the reply
        IndexError,                 # empty "choices" list
        TypeError,                  # reply has an unexpected structure
        AttributeError,             # content is not a string (e.g. null)
    ):
        # Only request/parse failures are mapped to the sentinel; a bare
        # except here would also swallow KeyboardInterrupt and SystemExit.
        return "Error"

# A label keyword that is not embedded in a longer alphabetic word, so
# "low_health_literacy" and "low-level" match but "follow" or "allows" do not.
_LABEL_KEYWORD_RE = re.compile(r"(?<![a-z])(low|intermediate|proficient)(?![a-z])")


def parse_label(text):
    """Map free-form text to one of the three health-literacy labels.

    The text after the last ``Label:`` marker is checked first, because the
    prompt asks for reasoning followed by a label and the reasoning may
    mention other levels. Without a usable marker the whole text is scanned
    and the first keyword present in the order low, intermediate, proficient
    wins.

    Args:
        text: Model output or a gold label string.

    Returns:
        ``"low_health_literacy"``, ``"intermediate_health_literacy"``,
        ``"proficient_health_literacy"``, or ``"unknown"`` when no label
        keyword is found.
    """
    lowered = text.lower()

    _, marker, tail = lowered.rpartition("label:")
    if marker:
        hit = _LABEL_KEYWORD_RE.search(tail)
        if hit:
            return f"{hit.group(1)}_health_literacy"

    found = set(_LABEL_KEYWORD_RE.findall(lowered))
    for keyword in ("low", "intermediate", "proficient"):
        if keyword in found:
            return f"{keyword}_health_literacy"
    return "unknown"

# Load the evaluation cases and the pool of candidate few-shot examples.
# The encoding is explicit so the result does not depend on the locale.
with open(DEV_SET_PATH, 'r', encoding='utf-8') as f:
    dev_set = json.load(f)
with open(FEW_SHOT_POOL_PATH, 'r', encoding='utf-8') as f:
    few_shot_pool = json.load(f)

if not dev_set:
    # Accuracy divides by len(dev_set); fail early with a clear message.
    raise ValueError(f"Dev set at {DEV_SET_PATH} is empty; accuracy cannot be computed.")

# Group the few-shot pool by label.
categorized = {}
for entry in few_shot_pool:
    categorized.setdefault(entry['label'], []).append(entry)

# For each label, enumerate every way of picking K examples from its pool.
label_combos = []
target_labels = ["low_health_literacy", "intermediate_health_literacy", "proficient_health_literacy"]

for label in target_labels:
    pool = categorized.get(label, [])
    if len(pool) < K:
        # combinations() yields nothing here, which makes the product empty.
        print(f"Warning: only {len(pool)} few-shot example(s) for '{label}'; need at least {K}.")
    label_combos.append(list(combinations(pool, K)))

# The number of configurations is the product of the per-label counts.
total_configs = 1
for combos in label_combos:
    total_configs *= len(combos)

# Iterate the cross product lazily; materialising it as a list grows
# combinatorially with the pool size.
all_possible_prompts_configs = product(*label_combos)

print(f"Total unique prompt configurations to test: {total_configs}")

# Gold labels do not depend on the prompt, so normalise them once.
gold_labels = [parse_label(case['label']) for case in dev_set]

results_log = []

for idx, config in enumerate(all_possible_prompts_configs):
    # config is one K-tuple per label; flatten into a single example list.
    flat_instances = [item for sublist in config for item in sublist]

    current_template = build_fixed_prompt(flat_instances)
    correct = 0

    for case, gold in zip(dev_set, gold_labels):
        pred = parse_label(get_prediction(current_template, case['gen_text']))
        if pred == gold:
            correct += 1

    accuracy = (correct / len(dev_set)) * 100

    # Record only identifiers so the results file stays small.
    config_metadata = [{"doc_id": inst['doc_id'], "label": inst['label']} for inst in flat_instances]
    results_log.append({
        "config_index": idx,
        "accuracy": accuracy,
        "instances": config_metadata
    })

    print(f"Config {idx+1}/{total_configs}: Accuracy = {accuracy:.2f}%")

# Best configuration first.
results_log.sort(key=lambda x: x['accuracy'], reverse=True)

output_path = "/home/mshahidul/readctrl/data/new_exp/exhaustive_3shot_results.json"
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(results_log, f, indent=4)

print("\n" + "="*50)
if results_log:
    best = results_log[0]
    print(f"WINNING CONFIGURATION (Acc: {best['accuracy']:.2f}%)")
    for inst in best['instances']:
        print(f"- {inst['label']}: {inst['doc_id']}")
else:
    # Reached when at least one label had fewer than K examples.
    print("No prompt configurations were evaluated.")
print("="*50)