Spaces:
Paused
Paused
| """ | |
| Utility script to build a clean reply-generation dataset for Models 1 & 2. | |
| Goal: | |
| - Tạo ra dataset mà output luôn là **câu trả lời từ phía Nam**, | |
| đã được chuẩn hoá theo prompt wingman (anh/em, 1 câu, ≤25 từ). | |
| Ý tưởng: | |
| - Đọc `new_data_selected.csv` gốc (user_text, partner_text, trigger_*, move_*) | |
| - Tính trigger/move chính (như khi fine-tune trigger detector) | |
| - Gom hội thoại: `Male: user_text ||| Female: partner_text` | |
| - Gọi wingman prompt-based (`ReplySuggestionService`) để sinh `male_reply` | |
| - Ghi ra CSV mới: `conversation,trigger,move,male_reply` | |
| Usage (local hoặc trong Spaces terminal): | |
| python build_reply_dataset.py \ | |
| --data_path new_data_selected.csv \ | |
| --output_path reply_training_data.csv | |
| Lưu ý: | |
| - Script dùng Hugging Face Inference API, cần HF_TOKEN có quyền call model (như trong reply_service.py). | |
| """ | |
| import argparse | |
| import os | |
| from typing import List | |
| import pandas as pd | |
| from reply_service import ReplySuggestionService | |
| def _detect_columns(df: pd.DataFrame, prefix: str) -> List[str]: | |
| cols = [col for col in df.columns if col.startswith(prefix)] | |
| if not cols: | |
| raise ValueError(f"No columns found with prefix '{prefix}'") | |
| return cols | |
| def _get_active_label(row: pd.Series, cols: List[str], prefix: str) -> str: | |
| """Lấy nhãn đầu tiên có giá trị 1, nếu không có thì trả về 'neutral'.""" | |
| for col in cols: | |
| if float(row.get(col, 0)) == 1.0: | |
| return col.replace(prefix, "") | |
| return "neutral" | |
| def _build_conversation(user_text: str, partner_text: str) -> str: | |
| user = (user_text or "").strip() | |
| partner = (partner_text or "").strip() | |
| if not user and not partner: | |
| return "" | |
| if user: | |
| return f"Male: {user} ||| Female: {partner}" | |
| return f"Female: {partner}" | |
| def main(): | |
| parser = argparse.ArgumentParser(description="Build reply-generation dataset from new_data_selected.csv") | |
| parser.add_argument( | |
| "--data_path", | |
| type=str, | |
| default="new_data_selected.csv", | |
| help="Đường dẫn đến CSV gốc (new_data_selected.csv)", | |
| ) | |
| parser.add_argument( | |
| "--output_path", | |
| type=str, | |
| default="reply_training_data.csv", | |
| help="Đường dẫn file CSV output chứa male_reply", | |
| ) | |
| parser.add_argument( | |
| "--max_rows", | |
| type=int, | |
| default=-1, | |
| help="Giới hạn số dòng xử lý (debug). -1 = dùng toàn bộ", | |
| ) | |
| args = parser.parse_args() | |
| print(f"[BUILD_DATASET] Loading dataset from {args.data_path}") | |
| df = pd.read_csv(args.data_path) | |
| if args.max_rows > 0: | |
| df = df.head(args.max_rows) | |
| print(f"[BUILD_DATASET] Using first {len(df)} rows for reply synthesis") | |
| trigger_cols = _detect_columns(df, "trigger_") | |
| move_cols = _detect_columns(df, "move_") | |
| # Khởi tạo wingman prompt-based | |
| print("[BUILD_DATASET] Initializing ReplySuggestionService (Inference API)...") | |
| reply_service = ReplySuggestionService() | |
| out_rows = [] | |
| for idx, row in df.iterrows(): | |
| user_text = str(row.get("user_text", "") or "") | |
| partner_text = str(row.get("partner_text", "") or "") | |
| conversation = _build_conversation(user_text, partner_text) | |
| if not conversation: | |
| continue | |
| trigger = _get_active_label(row, trigger_cols, "trigger_") | |
| move = _get_active_label(row, move_cols, "move_") | |
| try: | |
| male_reply = reply_service.suggest_reply( | |
| male=user_text, | |
| female=partner_text, | |
| tone=move, | |
| intent=trigger, | |
| ) | |
| except Exception as exc: | |
| print(f"[BUILD_DATASET] Row {idx}: reply generation failed: {exc}") | |
| continue | |
| out_rows.append( | |
| { | |
| "conversation": conversation, | |
| "user_text": user_text, | |
| "partner_text": partner_text, | |
| "trigger": trigger, | |
| "move": move, | |
| "male_reply": male_reply, | |
| } | |
| ) | |
| if (idx + 1) % 50 == 0: | |
| print(f"[BUILD_DATASET] Processed {idx + 1} rows...") | |
| out_df = pd.DataFrame(out_rows) | |
| out_df.to_csv(args.output_path, index=False) | |
| print(f"[BUILD_DATASET] Saved {len(out_df)} rows to {args.output_path}") | |
| if __name__ == "__main__": | |
| main() | |