|
|
import pandas as pd |
|
|
import argparse |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
import torch |
|
|
import random |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_to_natural_phrase(val):
    """Return ``val`` with underscores turned into spaces.

    Only strings that actually contain an underscore are rewritten; any other
    value (including non-string values) is returned unchanged.
    """
    needs_rewrite = isinstance(val, str) and "_" in val
    return val.replace("_", " ") if needs_rewrite else val
|
|
|
|
|
|
|
|
# System prompt used when rewriting an existing (question, short answer) pair.
_OPEN_TEXT_SYSTEM_PROMPT = (
    "You convert (Question, short Answer) into EXACTLY ONE natural English sentence that answers the Question.\n\n"
    "HARD RULES:\n"
    "- Output exactly ONE sentence. No newlines, no bullet points, no labels, no quotes.\n"
    "- Use ONLY the provided Answer content as the factual answer; do not add any new facts.\n"
    "- Be concise and direct.\n"
    "- Do NOT include any numbers unless the question is a COUNT question.\n"
    "- Vary phrasing strongly across items; avoid repeating the same structure.\n\n"
    "VARIABILITY REQUIREMENT (IMPORTANT):\n"
    "- For all questions, you MUST vary sentence structure.\n"
    "- Randomly choose ONE of these patterns each time:\n"
    " (A) Start with the sound name (Answer) -> then the relation.\n"
    " (B) Start with the relation -> then the sound name (Answer).\n"
    " (C) Use an 'it`s...' style clause after the Answer.\n"
    " (D) Use a short, natural rephrase with different verbs (e.g., lasts, continues, stands out, comes through).\n"
    "- Do not always use 'The sound with the ... is ...' — that pattern should be rare.\n\n"
    "TASK HANDLING (infer from the Question):\n"
    "- COUNT questions (how many / count / number):\n"
    " * If Answer is numeric, write it EITHER as digits (e.g., 10) OR as a word (e.g., ten). Do NOT include both.\n"
    "- DURATION questions (longest/shortest):\n"
    " * Clearly state longest vs shortest, and use the Answer as the sound name. Do not include any numbers.\n"
    "- VOLUME questions (minimum/maximum loudness, quietest/loudest):\n"
    " * Match minimum vs maximum loudness and use the Answer as the sound name. No dB values.\n"
    "- ORDER questions (first/second/before/after/second-to-last):\n"
    " * Match the requested relation and use the Answer as the sound name.\n\n"
    "Return only the sentence."
)

# System prompt used when turning an MCQ's correct option into a sentence.
_MCQ_SYSTEM_PROMPT = (
    "You are a helpful assistant that converts multiple-choice QA pairs into natural language answers.\n"
    "CRITICAL RULES:\n"
    "1. Write as a human would naturally speak - vary sentence structure and avoid repetitive patterns\n"
    "2. Keep responses concise but natural and affirmative avoiding words like 'might/may' or 'could' - one clear sentence\n"
    "3. Do not mention 'among the options/among the following' even if the question mentions it. This natural language statement is supposed to be a direct answer.\n"
    "4. Do NOT invent sounds.\n"
    "5. Do not reason to answer the question, you're just supposed to provide the correct mcq answer as a natural language answer in a single sentence.\n"
    "Return only the natural language answer, nothing else."
)


def _build_chat_messages(question, answer, mode):
    """Return the system/user chat messages for the requested prompt mode."""
    if mode == "open_text":
        system_text = _OPEN_TEXT_SYSTEM_PROMPT
        user_text = (
            f"Question: {question}\n"
            f"Answer: {answer}\n"
            "Rewrite the answer as a single, natural sentence that directly answers the question."
        )
    else:
        # Every mode other than "open_text" uses the MCQ-oriented prompt.
        system_text = _MCQ_SYSTEM_PROMPT
        user_text = (
            f"Now, given the question: '{question}' and the correct answer: '{answer}', "
            "write one natural-language answer as you would expect from a human."
        )
    return [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user_text},
    ]


def generate_answer(tokenizer, model, question, correct_value, device, mode="mcq"):
    """Generate a natural language answer using a text-only LLM.

    Args:
        tokenizer: Tokenizer providing ``apply_chat_template``, ``decode`` and
            ``eos_token_id``.
        model: Causal LM whose ``generate`` method is used for sampling.
        question: Question text inserted into the user prompt.
        correct_value: Short answer; underscores are replaced with spaces
            before it is placed in the prompt.
        device: Device the prompt token ids are moved to before generation.
        mode: ``"open_text"`` uses a direct rewrite prompt for provided
            question/answer pairs; any other value (default ``"mcq"``) uses
            the MCQ-oriented prompt.

    Returns:
        The decoded, whitespace-stripped text of the newly generated tokens
        (the prompt tokens are sliced off). The response is also printed.
    """
    answer_phrase = convert_to_natural_phrase(correct_value)
    chat = _build_chat_messages(question, answer_phrase, mode)

    prompt_ids = tokenizer.apply_chat_template(
        chat,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    prompt_ids = prompt_ids.to(device)
    # Remember the prompt length so only the continuation is decoded below.
    prompt_len = prompt_ids.shape[1]

    sampling_options = {
        "max_new_tokens": 64,
        "do_sample": True,
        "temperature": 0.8,
        "top_p": 0.9,
        "repetition_penalty": 1.05,
        "no_repeat_ngram_size": 3,
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    with torch.no_grad():
        sequences = model.generate(prompt_ids, **sampling_options)

    completion_ids = sequences[0, prompt_len:]
    reply = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
    print(f"Model response: {reply}")
    return reply
|
|
|
|
|
|
|
|
def detect_csv_format(df):
    """Identify the CSV layout of ``df`` and return its column mapping.

    Layouts are checked in this order, and the first match wins:

    - original MCQ format: has ``correct``, ``id`` and ``audio_path``
    - perturbed MCQ format: has ``answer``, ``idx`` and ``new_audio_path``
    - open-text format: has ``answer`` and ``question``

    Returns:
        A dict with the keys ``id_col``, ``audio_path_col``, ``answer_col``,
        ``question_col`` and ``format_type``. For the open-text layout,
        ``id_col`` / ``audio_path_col`` are ``None`` when the ``id`` /
        ``audio_path`` column is absent.

    Raises:
        ValueError: if none of the layouts matches the columns of ``df``.
    """
    columns = df.columns.tolist()
    available = set(columns)

    if {"correct", "id", "audio_path"} <= available:
        return dict(
            id_col="id",
            audio_path_col="audio_path",
            answer_col="correct",
            question_col="question",
            format_type="original",
        )

    if {"answer", "idx", "new_audio_path"} <= available:
        return dict(
            id_col="idx",
            audio_path_col="new_audio_path",
            answer_col="answer",
            question_col="question",
            format_type="perturbed",
        )

    if {"answer", "question"} <= available:
        # Identifier and audio path columns are optional in this layout.
        optional_id = "id" if "id" in available else None
        optional_audio = "audio_path" if "audio_path" in available else None
        return dict(
            id_col=optional_id,
            audio_path_col=optional_audio,
            answer_col="answer",
            question_col="question",
            format_type="open_text",
        )

    raise ValueError(f"Unknown CSV format. Columns found: {columns}")
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point: add natural-language answers to a QA CSV.

    Parses the CLI arguments, reads ``--input`` with pandas, detects its
    layout with ``detect_csv_format`` and checks it against ``--mode``. The
    tokenizer and model are loaded only after these checks, so an unusable
    CSV fails before the expensive model load. For each row the short answer
    is resolved (taken directly for open-text CSVs, looked up through the
    ``optionA``..``optionD`` columns for MCQ CSVs) and is either kept as-is
    (with probability ``--one_word_ratio``) or rewritten by
    ``generate_answer``. The answers are stored in an ``open_text_answer``
    column and the frame is written to ``--output``, or back to ``--input``
    when no output path is given.

    Raises:
        ValueError: if the CSV layout is unknown, the layout does not match
            ``--mode``, or an MCQ row's answer is not one of A/B/C/D.
    """
    parser = argparse.ArgumentParser(
        description="Convert CSV to NL answers (MCQ or open-text) using meta-llama/Llama-3.1-8B-Instruct"
    )
    parser.add_argument("--input", required=True, help="Input CSV file")
    parser.add_argument("--output", required=False, help="Output CSV file (defaults to input for in-place append)")
    parser.add_argument(
        "--mode",
        required=True,
        choices=["mcq", "open_text"],
        help="Conversion mode: mcq -> convert MCQ correct option to natural answer; open_text -> rewrite provided short answer to a natural sentence",
    )
    # --task is accepted for bookkeeping only; it does not affect processing.
    parser.add_argument(
        "--task",
        required=True,
        choices=["count", "duration", "order", "volume"],
        help="Task type this CSV belongs to (used for bookkeeping/logging)",
    )
    parser.add_argument(
        "--one_word_ratio",
        type=float,
        default=0.2,
        help="Fraction of samples to output as just the normalized one-word/phrase answer (no LLM forward pass). Default 0.2",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=123,
        help="Random seed for reproducible one_word sampling.",
    )

    args = parser.parse_args()
    random.seed(args.seed)

    # Read and validate the CSV first so that a bad input fails fast, before
    # the multi-gigabyte model is loaded.
    df = pd.read_csv(args.input)
    format_info = detect_csv_format(df)
    print(f"Detected CSV format: {format_info['format_type']}")

    is_open_text = format_info["format_type"] == "open_text"
    if args.mode == "mcq" and is_open_text:
        raise ValueError(
            "Requested mode=mcq but input appears to be open_text format. Use --mode open_text or supply an MCQ CSV."
        )
    if args.mode == "open_text" and not is_open_text:
        raise ValueError(
            "Requested mode=open_text but input does not appear to be open_text format. Use --mode mcq or supply an open_text CSV."
        )

    print("Loading meta-llama/Llama-3.1-8B-Instruct tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.1-8B-Instruct",
        torch_dtype="auto",
        device_map="auto",
    )
    model.eval()

    output_path = args.output if args.output else args.input
    device = model.device

    # Loop invariants hoisted out of the per-row work.
    question_col = format_info["question_col"]
    answer_col = format_info["answer_col"]
    option_map = {"A": "optionA", "B": "optionB", "C": "optionC", "D": "optionD"}
    generation_mode = "open_text" if is_open_text else "mcq"
    total_rows = len(df)

    nl_answers = []
    # Count rows by position: the index label yielded by iterrows() is not
    # guaranteed to be a 0-based integer.
    for position, (_, row) in enumerate(df.iterrows(), start=1):
        question = row[question_col]

        if is_open_text:
            correct_value = row[answer_col]
        else:
            raw_letter = row[answer_col]
            # Tolerate lowercase letters and stray whitespace in the CSV.
            correct_letter = str(raw_letter).strip().upper()
            if correct_letter not in option_map:
                raise ValueError(
                    f"Row {position}: correct answer must be one of A/B/C/D, got {raw_letter!r}"
                )
            correct_value = row[option_map[correct_letter]]

        correct_value = convert_to_natural_phrase(correct_value)
        print(f"[{position}/{total_rows}] Q: {question} | Ans: {correct_value}")

        if random.random() < args.one_word_ratio:
            # Keep a fraction of answers as the bare short answer (no LLM call).
            nl_answer = correct_value
            print(f"Skipped LLM (one_word_ratio). Output: {nl_answer}")
        else:
            nl_answer = generate_answer(
                tokenizer,
                model,
                question,
                correct_value,
                device,
                mode=generation_mode,
            )

        nl_answers.append(nl_answer)

    # Assign positionally from a plain list: this also works for an empty CSV
    # and does not rely on index alignment with an intermediate DataFrame.
    df["open_text_answer"] = nl_answers
    df.to_csv(output_path, index=False)
    print(f"Appended natural language answers to {output_path}")
|
|
|
|
|
|
|
|
# Run the conversion only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|