lovebird25 / build_reply_dataset.py
Paul
update
1a6e95f
"""
Utility script to build a clean reply-generation dataset for Models 1 & 2.
Goal:
- Tạo ra dataset mà output luôn là **câu trả lời từ phía Nam**,
đã được chuẩn hoá theo prompt wingman (anh/em, 1 câu, ≤25 từ).
Ý tưởng:
- Đọc `new_data_selected.csv` gốc (user_text, partner_text, trigger_*, move_*)
- Tính trigger/move chính (như khi fine-tune trigger detector)
- Gom hội thoại: `Male: user_text ||| Female: partner_text`
- Gọi wingman prompt-based (`ReplySuggestionService`) để sinh `male_reply`
- Ghi ra CSV mới: `conversation,trigger,move,male_reply`
Usage (local hoặc trong Spaces terminal):
python build_reply_dataset.py \
--data_path new_data_selected.csv \
--output_path reply_training_data.csv
Lưu ý:
- Script dùng Hugging Face Inference API, cần HF_TOKEN có quyền call model (như trong reply_service.py).
"""
import argparse
import os
from typing import List
import pandas as pd
from reply_service import ReplySuggestionService
def _detect_columns(df: pd.DataFrame, prefix: str) -> List[str]:
cols = [col for col in df.columns if col.startswith(prefix)]
if not cols:
raise ValueError(f"No columns found with prefix '{prefix}'")
return cols
def _get_active_label(row: pd.Series, cols: List[str], prefix: str) -> str:
"""Lấy nhãn đầu tiên có giá trị 1, nếu không có thì trả về 'neutral'."""
for col in cols:
if float(row.get(col, 0)) == 1.0:
return col.replace(prefix, "")
return "neutral"
def _build_conversation(user_text: str, partner_text: str) -> str:
user = (user_text or "").strip()
partner = (partner_text or "").strip()
if not user and not partner:
return ""
if user:
return f"Male: {user} ||| Female: {partner}"
return f"Female: {partner}"
def main():
parser = argparse.ArgumentParser(description="Build reply-generation dataset from new_data_selected.csv")
parser.add_argument(
"--data_path",
type=str,
default="new_data_selected.csv",
help="Đường dẫn đến CSV gốc (new_data_selected.csv)",
)
parser.add_argument(
"--output_path",
type=str,
default="reply_training_data.csv",
help="Đường dẫn file CSV output chứa male_reply",
)
parser.add_argument(
"--max_rows",
type=int,
default=-1,
help="Giới hạn số dòng xử lý (debug). -1 = dùng toàn bộ",
)
args = parser.parse_args()
print(f"[BUILD_DATASET] Loading dataset from {args.data_path}")
df = pd.read_csv(args.data_path)
if args.max_rows > 0:
df = df.head(args.max_rows)
print(f"[BUILD_DATASET] Using first {len(df)} rows for reply synthesis")
trigger_cols = _detect_columns(df, "trigger_")
move_cols = _detect_columns(df, "move_")
# Khởi tạo wingman prompt-based
print("[BUILD_DATASET] Initializing ReplySuggestionService (Inference API)...")
reply_service = ReplySuggestionService()
out_rows = []
for idx, row in df.iterrows():
user_text = str(row.get("user_text", "") or "")
partner_text = str(row.get("partner_text", "") or "")
conversation = _build_conversation(user_text, partner_text)
if not conversation:
continue
trigger = _get_active_label(row, trigger_cols, "trigger_")
move = _get_active_label(row, move_cols, "move_")
try:
male_reply = reply_service.suggest_reply(
male=user_text,
female=partner_text,
tone=move,
intent=trigger,
)
except Exception as exc:
print(f"[BUILD_DATASET] Row {idx}: reply generation failed: {exc}")
continue
out_rows.append(
{
"conversation": conversation,
"user_text": user_text,
"partner_text": partner_text,
"trigger": trigger,
"move": move,
"male_reply": male_reply,
}
)
if (idx + 1) % 50 == 0:
print(f"[BUILD_DATASET] Processed {idx + 1} rows...")
out_df = pd.DataFrame(out_rows)
out_df.to_csv(args.output_path, index=False)
print(f"[BUILD_DATASET] Saved {len(out_df)} rows to {args.output_path}")
if __name__ == "__main__":
main()