File size: 11,167 Bytes
fec9168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
import pandas as pd
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import random

# Convert MCQ or open-text QA CSVs into natural-language answers using a text-only LLM
# (meta-llama/Llama-3.1-8B-Instruct). Features:
#   (1) In open_text mode, the system prompt pushes the LLM to vary sentence structure
#       strongly (including for duration/volume questions).
#   (2) --one_word_ratio (default 0.2) skips the LLM forward pass for that fraction of rows
#       and outputs only the normalized answer (underscores replaced by spaces).


def convert_to_natural_phrase(val):
    """Turn an underscore-joined token such as ``"dog_bark"`` into ``"dog bark"``.

    Only strings that actually contain an underscore are rewritten; every other
    value (strings without underscores, numbers, NaN, None, ...) is handed back
    untouched, so this is safe to apply to raw CSV cells.
    """
    needs_spacing = isinstance(val, str) and "_" in val
    return val.replace("_", " ") if needs_spacing else val


def generate_answer(
    tokenizer,
    model,
    question,
    correct_value,
    device,
    mode="mcq",
    max_new_tokens=64,
    temperature=0.8,
    top_p=0.9,
):
    """Generate a natural language answer using a text-only LLM.

    Builds a system + user chat prompt, renders it with the tokenizer's chat
    template, samples a continuation from ``model`` and returns only the newly
    generated text (the prompt tokens are sliced off before decoding).

    Args:
        tokenizer: tokenizer exposing ``apply_chat_template``, ``decode`` and
            ``eos_token_id`` (used here as both pad and stop token).
        model: causal LM exposing ``generate``.
        question: question text inserted into the user prompt.
        correct_value: ground-truth answer; underscores in string values are
            replaced by spaces before prompting.
        device: device the prompt tensors are moved to before generation.
        mode: "mcq" (default) uses the original MCQ-oriented prompt;
              "open_text" uses a direct rewrite prompt for provided question/answer pairs.
              Any other value falls back to the MCQ prompt.
        max_new_tokens, temperature, top_p: sampling settings forwarded to
            ``model.generate``; the defaults are the values previously hard-coded.

    Returns:
        The decoded model response with special tokens removed and surrounding
        whitespace stripped.
    """
    correct_value = convert_to_natural_phrase(correct_value)

    if mode == "open_text":
        system_preamble = (
            "You convert (Question, short Answer) into EXACTLY ONE natural English sentence that answers the Question.\n\n"
            "HARD RULES:\n"
            "- Output exactly ONE sentence. No newlines, no bullet points, no labels, no quotes.\n"
            "- Use ONLY the provided Answer content as the factual answer; do not add any new facts.\n"
            "- Be concise and direct.\n"
            "- Do NOT include any numbers unless the question is a COUNT question.\n"
            "- Vary phrasing strongly across items; avoid repeating the same structure.\n\n"
            "VARIABILITY REQUIREMENT (IMPORTANT):\n"
            "- For all questions, you MUST vary sentence structure.\n"
            "- Randomly choose ONE of these patterns each time:\n"
            "  (A) Start with the sound name (Answer) -> then the relation.\n"
            "  (B) Start with the relation -> then the sound name (Answer).\n"
            # Fixed typo: the pattern previously read "it`s" (backtick) instead of "it's".
            "  (C) Use an 'it's...' style clause after the Answer.\n"
            "  (D) Use a short, natural rephrase with different verbs (e.g., lasts, continues, stands out, comes through).\n"
            "- Do not always use 'The sound with the ... is ...' — that pattern should be rare.\n\n"
            "TASK HANDLING (infer from the Question):\n"
            "- COUNT questions (how many / count / number):\n"
            "  * If Answer is numeric, write it EITHER as digits (e.g., 10) OR as a word (e.g., ten). Do NOT include both.\n"
            "- DURATION questions (longest/shortest):\n"
            "  * Clearly state longest vs shortest, and use the Answer as the sound name. Do not include any numbers.\n"
            "- VOLUME questions (minimum/maximum loudness, quietest/loudest):\n"
            "  * Match minimum vs maximum loudness and use the Answer as the sound name. No dB values.\n"
            "- ORDER questions (first/second/before/after/second-to-last):\n"
            "  * Match the requested relation and use the Answer as the sound name.\n\n"
            "Return only the sentence."
        )

        user_prompt = (
            f"Question: {question}\n"
            f"Answer: {correct_value}\n"
            "Rewrite the answer as a single, natural sentence that directly answers the question."
        )
    else:
        system_preamble = (
            "You are a helpful assistant that converts multiple-choice QA pairs into natural language answers.\n"
            "CRITICAL RULES:\n"
            "1. Write as a human would naturally speak - vary sentence structure and avoid repetitive patterns\n"
            "2. Keep responses concise but natural and affirmative avoiding words like 'might/may' or 'could' - one clear sentence\n"
            "3. Do not mention 'among the options/among the following' even if the question mentions it. This natural language statement is supposed to be a direct answer.\n"
            "4. Do NOT invent sounds.\n"
            "5. Do not reason to answer the question, you're just supposed to provide the correct mcq answer as a natural language answer in a single sentence.\n"
            "Return only the natural language answer, nothing else."
        )
        user_prompt = (
            f"Now, given the question: '{question}' and the correct answer: '{correct_value}', "
            f"write one natural-language answer as you would expect from a human."
        )

    # Chat format
    messages = [
        {"role": "system", "content": system_preamble},
        {"role": "user", "content": user_prompt},
    ]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(device)

    # A single unpadded prompt: every token is real. pad_token_id equals
    # eos_token_id below, so generate() cannot infer the mask itself and warns on
    # each call; an explicit all-ones mask states the intent and silences that.
    attention_mask = torch.ones_like(inputs)

    input_length = inputs.shape[1]

    with torch.no_grad():
        output = model.generate(
            inputs,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=1.05,
            no_repeat_ngram_size=3,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Keep only the newly generated tokens (drop the echoed prompt).
    generated_ids = output[0, input_length:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    print(f"Model response: {response}")
    return response


def detect_csv_format(df):
    """Identify which supported CSV layout ``df`` follows and map its columns.

    Layouts are checked in this priority order (first match wins):

    * ``original`` MCQ (e.g. count.csv) -- requires ``correct``, ``id`` and ``audio_path``.
    * ``perturbed`` MCQ (e.g. count_perturbed.csv) -- requires ``answer``, ``idx``
      and ``new_audio_path``.
    * ``open_text`` -- requires ``answer`` and ``question``; ``id`` and
      ``audio_path`` are mapped only when present, otherwise set to None.

    Only the columns listed above are tested; the mapped ``question`` column
    (and, for MCQ layouts, the option columns) are not verified here.

    Returns:
        dict with keys ``id_col``, ``audio_path_col``, ``answer_col``,
        ``question_col`` and ``format_type``.

    Raises:
        ValueError: if no layout matches; the message lists the columns found.
    """
    found = df.columns.tolist()
    available = set(found)

    def _has_all(*names):
        return available.issuperset(names)

    if _has_all("correct", "id", "audio_path"):
        id_col, audio_col, answer_col, kind = "id", "audio_path", "correct", "original"
    elif _has_all("answer", "idx", "new_audio_path"):
        id_col, audio_col, answer_col, kind = "idx", "new_audio_path", "answer", "perturbed"
    elif _has_all("answer", "question"):
        # Open-text CSVs may omit identifiers / audio paths entirely.
        id_col = "id" if "id" in available else None
        audio_col = "audio_path" if "audio_path" in available else None
        answer_col, kind = "answer", "open_text"
    else:
        raise ValueError(f"Unknown CSV format. Columns found: {found}")

    return {
        "id_col": id_col,
        "audio_path_col": audio_col,
        "answer_col": answer_col,
        "question_col": "question",
        "format_type": kind,
    }


def _resolve_correct_value(row, format_info):
    """Return the normalized ground-truth answer for one CSV row.

    open_text rows store the answer text directly in the answer column. MCQ rows
    (original/perturbed) store an option letter there that selects one of the
    optionA..optionD columns. Underscores in string answers become spaces.

    Raises:
        ValueError: if an MCQ row's answer letter is not one of A-D.
    """
    if format_info["format_type"] == "open_text":
        value = row[format_info["answer_col"]]
    else:
        letter = row[format_info["answer_col"]]
        option_map = {"A": "optionA", "B": "optionB", "C": "optionC", "D": "optionD"}
        # Tolerate stray whitespace / lowercase letters (e.g. " b") instead of
        # crashing with a bare KeyError.
        option_col = option_map.get(str(letter).strip().upper())
        if option_col is None:
            raise ValueError(
                f"Unrecognized MCQ answer letter {letter!r}; expected one of A, B, C, D."
            )
        value = row[option_col]
    return convert_to_natural_phrase(value)


def main():
    """Command-line entry point: add an ``open_text_answer`` column to a QA CSV.

    Reads ``--input``, detects its layout with ``detect_csv_format`` and checks it
    against ``--mode`` *before* loading the model, so a mismatched or unrecognized
    CSV fails immediately instead of after the model load. For each row, with
    probability ``--one_word_ratio`` the normalized answer is copied verbatim (no
    forward pass); otherwise ``generate_answer`` rewrites it as a sentence. All
    original columns are kept, and the result is written to ``--output`` (or back
    to ``--input`` when ``--output`` is omitted).
    """
    parser = argparse.ArgumentParser(
        description="Convert CSV to NL answers (MCQ or open-text) using meta-llama/Llama-3.1-8B-Instruct"
    )
    parser.add_argument("--input", required=True, help="Input CSV file")
    parser.add_argument("--output", required=False, help="Output CSV file (defaults to input for in-place append)")
    parser.add_argument(
        "--mode",
        required=True,
        choices=["mcq", "open_text"],
        help="Conversion mode: mcq -> convert MCQ correct option to natural answer; open_text -> rewrite provided short answer to a natural sentence",
    )
    parser.add_argument(
        "--task",
        required=True,
        choices=["count", "duration", "order", "volume"],
        help="Task type this CSV belongs to (used for bookkeeping/logging)",
    )

    # One-word skipping: emit the normalized answer itself for a fraction of rows.
    parser.add_argument(
        "--one_word_ratio",
        type=float,
        default=0.2,
        help="Fraction of samples to output as just the normalized one-word/phrase answer (no LLM forward pass). Default 0.2",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=123,
        help="Random seed for reproducible one_word sampling and LLM sampling.",
    )

    args = parser.parse_args()
    random.seed(args.seed)
    # generate_answer samples with do_sample=True, which draws from torch's RNG;
    # seed it too so reruns with the same --seed reproduce the same sentences
    # (best-effort: some GPU kernels are not bitwise deterministic).
    torch.manual_seed(args.seed)

    # Read and validate the CSV first: these cheap checks should not wait on the
    # (slow, memory-heavy) model load.
    df = pd.read_csv(args.input)

    # Detect CSV format and get column mappings
    format_info = detect_csv_format(df)
    print(f"Detected CSV format: {format_info['format_type']}")
    print(f"Task: {args.task}")

    # Validate requested mode against detected CSV format
    if args.mode == "mcq" and format_info["format_type"] == "open_text":
        raise ValueError(
            "Requested mode=mcq but input appears to be open_text format. Use --mode open_text or supply an MCQ CSV."
        )
    if args.mode == "open_text" and format_info["format_type"] != "open_text":
        raise ValueError(
            "Requested mode=open_text but input does not appear to be open_text format. Use --mode mcq or supply an open_text CSV."
        )

    output_path = args.output or args.input

    print("Loading meta-llama/Llama-3.1-8B-Instruct tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.1-8B-Instruct",
        torch_dtype="auto",
        device_map="auto",
    )
    model.eval()
    device = model.device

    gen_mode = "open_text" if format_info["format_type"] == "open_text" else "mcq"
    total = len(df)
    nl_answers = []

    # Report 1-based row positions (not index labels) in the progress output.
    for position, (_, row) in enumerate(df.iterrows(), start=1):
        question = row[format_info["question_col"]]
        # Normalization (underscores -> spaces) happens BEFORE the one_word decision.
        correct_value = _resolve_correct_value(row, format_info)

        print(f"[{position}/{total}] Q: {question} | Ans: {correct_value}")

        # With probability --one_word_ratio: one-word/phrase answer, no forward pass.
        if random.random() < args.one_word_ratio:
            nl_answer = correct_value
            print(f"Skipped LLM (one_word_ratio). Output: {nl_answer}")
        else:
            nl_answer = generate_answer(
                tokenizer,
                model,
                question,
                correct_value,
                device,
                mode=gen_mode,
            )
        nl_answers.append(nl_answer)

    # Assign positionally from a plain list: the column always lines up with df's
    # rows regardless of its index, and a header-only CSV (zero rows) no longer
    # raises KeyError as the previous DataFrame-based merge did.
    df["open_text_answer"] = nl_answers
    df.to_csv(output_path, index=False)
    print(f"Appended natural language answers to {output_path}")


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when imported.
    main()