# TREA_2.0_codebase / llm_answer_generator.py
# Source: Hugging Face upload by malay-36 ("Upload folder using huggingface_hub"),
# commit fec9168 (verified).
import pandas as pd
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import random
# Convert an MCQ or open-text QA CSV into natural-language answers using a text-only LLM
# (meta-llama/Llama-3.1-8B-Instruct).
# Adds: (1) stronger LLM-driven phrasing variability in open_text mode via the system prompt;
# (2) --one_word_ratio (default 0.2) to skip the LLM forward pass for a fraction of rows,
#     outputting only the normalized (underscore-removed) answer.
def convert_to_natural_phrase(val):
    """Turn an underscore-joined token such as ``dog_bark`` into ``dog bark``.

    Only strings that actually contain an underscore are rewritten; every other
    value (strings without underscores, numbers, NaN, None) is returned as-is.
    """
    needs_rewrite = isinstance(val, str) and "_" in val
    return val.replace("_", " ") if needs_rewrite else val
def generate_answer(tokenizer, model, question, correct_value, device, mode="mcq"):
    """Generate a one-sentence natural-language answer with a chat LLM.

    Args:
        tokenizer: Hugging Face tokenizer exposing ``apply_chat_template`` and ``decode``.
        model: causal LM exposing ``generate``.
        question: question text shown to the model.
        correct_value: ground-truth answer; underscores are converted to spaces first.
        device: device the prompt tensors are moved to (``model.device`` in ``main``).
        mode: ``"mcq"`` (default) uses the MCQ-oriented prompt; ``"open_text"`` uses a
            direct (question, short answer) -> sentence rewrite prompt.

    Returns:
        The decoded completion as a single line: surrounding whitespace is stripped and
        internal whitespace runs (including any newlines) collapse to one space.

    Raises:
        ValueError: if ``mode`` is neither ``"mcq"`` nor ``"open_text"``.
    """
    if mode not in ("mcq", "open_text"):
        # Previously any unknown mode silently fell through to the MCQ prompt.
        raise ValueError(f"Unknown mode {mode!r}; expected 'mcq' or 'open_text'.")
    correct_value = convert_to_natural_phrase(correct_value)
    if mode == "open_text":
        system_preamble = (
            "You convert (Question, short Answer) into EXACTLY ONE natural English sentence that answers the Question.\n\n"
            "HARD RULES:\n"
            "- Output exactly ONE sentence. No newlines, no bullet points, no labels, no quotes.\n"
            "- Use ONLY the provided Answer content as the factual answer; do not add any new facts.\n"
            "- Be concise and direct.\n"
            "- Do NOT include any numbers unless the question is a COUNT question.\n"
            "- Vary phrasing strongly across items; avoid repeating the same structure.\n\n"
            "VARIABILITY REQUIREMENT (IMPORTANT):\n"
            "- For all questions, you MUST vary sentence structure.\n"
            "- Randomly choose ONE of these patterns each time:\n"
            " (A) Start with the sound name (Answer) -> then the relation.\n"
            " (B) Start with the relation -> then the sound name (Answer).\n"
            " (C) Use an 'it's...' style clause after the Answer.\n"
            " (D) Use a short, natural rephrase with different verbs (e.g., lasts, continues, stands out, comes through).\n"
            "- Do not always use 'The sound with the ... is ...' — that pattern should be rare.\n\n"
            "TASK HANDLING (infer from the Question):\n"
            "- COUNT questions (how many / count / number):\n"
            " * If Answer is numeric, write it EITHER as digits (e.g., 10) OR as a word (e.g., ten). Do NOT include both.\n"
            "- DURATION questions (longest/shortest):\n"
            " * Clearly state longest vs shortest, and use the Answer as the sound name. Do not include any numbers.\n"
            "- VOLUME questions (minimum/maximum loudness, quietest/loudest):\n"
            " * Match minimum vs maximum loudness and use the Answer as the sound name. No dB values.\n"
            "- ORDER questions (first/second/before/after/second-to-last):\n"
            " * Match the requested relation and use the Answer as the sound name.\n\n"
            "Return only the sentence."
        )
        user_prompt = (
            f"Question: {question}\n"
            f"Answer: {correct_value}\n"
            "Rewrite the answer as a single, natural sentence that directly answers the question."
        )
    else:
        system_preamble = (
            "You are a helpful assistant that converts multiple-choice QA pairs into natural language answers.\n"
            "CRITICAL RULES:\n"
            "1. Write as a human would naturally speak - vary sentence structure and avoid repetitive patterns\n"
            "2. Keep responses concise but natural and affirmative avoiding words like 'might/may' or 'could' - one clear sentence\n"
            "3. Do not mention 'among the options/among the following' even if the question mentions it. This natural language statement is supposed to be a direct answer.\n"
            "4. Do NOT invent sounds.\n"
            "5. Do not reason to answer the question, you're just supposed to provide the correct mcq answer as a natural language answer in a single sentence.\n"
            "Return only the natural language answer, nothing else."
        )
        user_prompt = (
            f"Now, given the question: '{question}' and the correct answer: '{correct_value}', "
            f"write one natural-language answer as you would expect from a human."
        )
    # Chat format
    messages = [
        {"role": "system", "content": system_preamble},
        {"role": "user", "content": user_prompt},
    ]
    # return_dict=True also returns the attention mask. generate() needs it because
    # pad_token_id is set to eos_token_id below, so padding cannot be inferred.
    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(device)
    input_ids = encoded["input_ids"]
    input_length = input_ids.shape[1]
    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            attention_mask=encoded["attention_mask"],
            max_new_tokens=64,
            do_sample=True,
            temperature=0.8,
            top_p=0.9,
            repetition_penalty=1.05,
            no_repeat_ngram_size=3,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Keep only the newly generated tokens (drop the echoed prompt).
    generated_ids = output[0, input_length:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True)
    # The prompt demands a single line; enforce it so each CSV cell stays on one line.
    response = " ".join(response.split())
    print(f"Model response: {response}")
    return response
def detect_csv_format(df):
    """Identify which supported CSV layout *df* uses and map its columns.

    Layouts are tried in priority order:
      * ``original``  - MCQ with ``id``, ``audio_path`` and a ``correct`` letter column
        (e.g. count.csv);
      * ``perturbed`` - MCQ with ``idx``, ``new_audio_path`` and an ``answer`` letter
        column (e.g. count_perturbed.csv);
      * ``open_text`` - any table with ``question`` and ``answer`` columns; ``id`` and
        ``audio_path`` are mapped only when present, otherwise ``None``.

    Returns:
        dict with keys ``id_col``, ``audio_path_col``, ``answer_col``, ``question_col``
        and ``format_type``.

    Raises:
        ValueError: if no supported layout matches.
    """
    columns = df.columns.tolist()

    def has_all(*names):
        return all(name in columns for name in names)

    if has_all("correct", "id", "audio_path"):
        layout = {
            "id_col": "id",
            "audio_path_col": "audio_path",
            "answer_col": "correct",
            "question_col": "question",
            "format_type": "original",
        }
    elif has_all("answer", "idx", "new_audio_path"):
        layout = {
            "id_col": "idx",
            "audio_path_col": "new_audio_path",
            "answer_col": "answer",
            "question_col": "question",
            "format_type": "perturbed",
        }
    elif has_all("answer", "question"):
        layout = {
            "id_col": "id" if has_all("id") else None,
            "audio_path_col": "audio_path" if has_all("audio_path") else None,
            "answer_col": "answer",
            "question_col": "question",
            "format_type": "open_text",
        }
    else:
        raise ValueError(f"Unknown CSV format. Columns found: {columns}")
    return layout
_MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"

# Maps an MCQ answer letter to the CSV column holding that option's text.
_OPTION_COLUMNS = {"A": "optionA", "B": "optionB", "C": "optionC", "D": "optionD"}


def _parse_args():
    """Build the CLI, parse argv, and validate ranges that argparse cannot express."""
    parser = argparse.ArgumentParser(
        description="Convert CSV to NL answers (MCQ or open-text) using meta-llama/Llama-3.1-8B-Instruct"
    )
    parser.add_argument("--input", required=True, help="Input CSV file")
    parser.add_argument("--output", required=False, help="Output CSV file (defaults to input for in-place append)")
    parser.add_argument(
        "--mode",
        required=True,
        choices=["mcq", "open_text"],
        help="Conversion mode: mcq -> convert MCQ correct option to natural answer; open_text -> rewrite provided short answer to a natural sentence",
    )
    parser.add_argument(
        "--task",
        required=True,
        choices=["count", "duration", "order", "volume"],
        help="Task type this CSV belongs to (used for bookkeeping/logging)",
    )
    parser.add_argument(
        "--one_word_ratio",
        type=float,
        default=0.2,
        help="Fraction of samples to output as just the normalized one-word/phrase answer (no LLM forward pass). Default 0.2",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=123,
        help="Random seed for reproducible one_word sampling and LLM sampling.",
    )
    args = parser.parse_args()
    if not 0.0 <= args.one_word_ratio <= 1.0:
        parser.error(f"--one_word_ratio must be within [0, 1], got {args.one_word_ratio}")
    return args


def _validate_mode(mode, format_type):
    """Raise ValueError when the requested --mode disagrees with the detected CSV layout."""
    if mode == "mcq" and format_type == "open_text":
        raise ValueError(
            "Requested mode=mcq but input appears to be open_text format. Use --mode open_text or supply an MCQ CSV."
        )
    if mode == "open_text" and format_type != "open_text":
        raise ValueError(
            "Requested mode=open_text but input does not appear to be open_text format. Use --mode mcq or supply an open_text CSV."
        )


def _resolve_correct_value(row, format_info):
    """Return the raw ground-truth answer for *row*.

    Open-text rows carry the answer directly. MCQ rows carry a letter whose option
    text lives in the optionA..optionD columns. The letter is stripped and
    upper-cased so that values like ``" a"`` resolve too.
    """
    raw = row[format_info["answer_col"]]
    if format_info["format_type"] == "open_text":
        return raw
    letter = str(raw).strip().upper()
    if letter not in _OPTION_COLUMNS:
        raise ValueError(f"Unexpected MCQ answer {raw!r}; expected one of A, B, C, D.")
    return row[_OPTION_COLUMNS[letter]]


def _load_model():
    """Load the tokenizer and model once and put the model in inference mode."""
    print("Loading meta-llama/Llama-3.1-8B-Instruct tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        _MODEL_ID,
        torch_dtype="auto",
        device_map="auto",
    )
    model.eval()
    return tokenizer, model


def main():
    """CLI entry point: append an ``open_text_answer`` column to the input CSV.

    For each row, with probability ``--one_word_ratio`` the normalized answer is used
    verbatim (no forward pass); otherwise the LLM rewrites it as one sentence. The
    result is written to ``--output``, or back to ``--input`` when no output is given.
    """
    args = _parse_args()
    random.seed(args.seed)
    # Also seed torch so do_sample=True generation is reproducible for a given --seed.
    torch.manual_seed(args.seed)

    # Read and validate the CSV before loading the 8B model, so a mode/format
    # mismatch fails fast instead of after a long model load.
    df = pd.read_csv(args.input)
    format_info = detect_csv_format(df)
    print(f"Detected CSV format: {format_info['format_type']}")
    _validate_mode(args.mode, format_info["format_type"])
    print(f"Task: {args.task} | mode: {args.mode} | rows: {len(df)}")

    tokenizer, model = _load_model()
    device = model.device
    output_path = args.output if args.output else args.input

    answers = []
    for position, (_, row) in enumerate(df.iterrows(), start=1):
        question = row[format_info["question_col"]]
        # Normalize underscores BEFORE deciding the one_word skip.
        correct_value = convert_to_natural_phrase(_resolve_correct_value(row, format_info))
        print(f"[{position}/{len(df)}] Q: {question} | Ans: {correct_value}")
        # With probability one_word_ratio: bare normalized answer, no forward pass.
        if random.random() < args.one_word_ratio:
            nl_answer = correct_value
            print(f"Skipped LLM (one_word_ratio). Output: {nl_answer}")
        else:
            nl_answer = generate_answer(
                tokenizer,
                model,
                question,
                correct_value,
                device,
                mode=args.mode,
            )
        answers.append(nl_answer)

    # Add the new column to the original frame so every existing field is preserved.
    # Assigning a list (one entry per row, in order) also works for an empty CSV.
    df["open_text_answer"] = answers
    df.to_csv(output_path, index=False)
    print(f"Appended natural language answers to {output_path}")


if __name__ == "__main__":
    main()