Spaces:
Paused
Paused
Paul commited on
Commit ·
1a6e95f
1
Parent(s): bbce197
update
Browse files- build_reply_dataset.py +140 -0
- finetune_model.py +52 -27
build_reply_dataset.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility script to build a clean reply-generation dataset for Models 1 & 2.
|
| 3 |
+
|
| 4 |
+
Goal:
|
| 5 |
+
- Tạo ra dataset mà output luôn là **câu trả lời từ phía Nam**,
|
| 6 |
+
đã được chuẩn hoá theo prompt wingman (anh/em, 1 câu, ≤25 từ).
|
| 7 |
+
|
| 8 |
+
Ý tưởng:
|
| 9 |
+
- Đọc `new_data_selected.csv` gốc (user_text, partner_text, trigger_*, move_*)
|
| 10 |
+
- Tính trigger/move chính (như khi fine-tune trigger detector)
|
| 11 |
+
- Gom hội thoại: `Male: user_text ||| Female: partner_text`
|
| 12 |
+
- Gọi wingman prompt-based (`ReplySuggestionService`) để sinh `male_reply`
|
| 13 |
+
- Ghi ra CSV mới: `conversation,trigger,move,male_reply`
|
| 14 |
+
|
| 15 |
+
Usage (local hoặc trong Spaces terminal):
|
| 16 |
+
python build_reply_dataset.py \
|
| 17 |
+
--data_path new_data_selected.csv \
|
| 18 |
+
--output_path reply_training_data.csv
|
| 19 |
+
|
| 20 |
+
Lưu ý:
|
| 21 |
+
- Script dùng Hugging Face Inference API, cần HF_TOKEN có quyền call model (như trong reply_service.py).
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import argparse
|
| 25 |
+
import os
|
| 26 |
+
from typing import List
|
| 27 |
+
|
| 28 |
+
import pandas as pd
|
| 29 |
+
|
| 30 |
+
from reply_service import ReplySuggestionService
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _detect_columns(df: pd.DataFrame, prefix: str) -> List[str]:
|
| 34 |
+
cols = [col for col in df.columns if col.startswith(prefix)]
|
| 35 |
+
if not cols:
|
| 36 |
+
raise ValueError(f"No columns found with prefix '{prefix}'")
|
| 37 |
+
return cols
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _get_active_label(row: pd.Series, cols: List[str], prefix: str) -> str:
|
| 41 |
+
"""Lấy nhãn đầu tiên có giá trị 1, nếu không có thì trả về 'neutral'."""
|
| 42 |
+
for col in cols:
|
| 43 |
+
if float(row.get(col, 0)) == 1.0:
|
| 44 |
+
return col.replace(prefix, "")
|
| 45 |
+
return "neutral"
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _build_conversation(user_text: str, partner_text: str) -> str:
|
| 49 |
+
user = (user_text or "").strip()
|
| 50 |
+
partner = (partner_text or "").strip()
|
| 51 |
+
if not user and not partner:
|
| 52 |
+
return ""
|
| 53 |
+
if user:
|
| 54 |
+
return f"Male: {user} ||| Female: {partner}"
|
| 55 |
+
return f"Female: {partner}"
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def main():
|
| 59 |
+
parser = argparse.ArgumentParser(description="Build reply-generation dataset from new_data_selected.csv")
|
| 60 |
+
parser.add_argument(
|
| 61 |
+
"--data_path",
|
| 62 |
+
type=str,
|
| 63 |
+
default="new_data_selected.csv",
|
| 64 |
+
help="Đường dẫn đến CSV gốc (new_data_selected.csv)",
|
| 65 |
+
)
|
| 66 |
+
parser.add_argument(
|
| 67 |
+
"--output_path",
|
| 68 |
+
type=str,
|
| 69 |
+
default="reply_training_data.csv",
|
| 70 |
+
help="Đường dẫn file CSV output chứa male_reply",
|
| 71 |
+
)
|
| 72 |
+
parser.add_argument(
|
| 73 |
+
"--max_rows",
|
| 74 |
+
type=int,
|
| 75 |
+
default=-1,
|
| 76 |
+
help="Giới hạn số dòng xử lý (debug). -1 = dùng toàn bộ",
|
| 77 |
+
)
|
| 78 |
+
args = parser.parse_args()
|
| 79 |
+
|
| 80 |
+
print(f"[BUILD_DATASET] Loading dataset from {args.data_path}")
|
| 81 |
+
df = pd.read_csv(args.data_path)
|
| 82 |
+
|
| 83 |
+
if args.max_rows > 0:
|
| 84 |
+
df = df.head(args.max_rows)
|
| 85 |
+
print(f"[BUILD_DATASET] Using first {len(df)} rows for reply synthesis")
|
| 86 |
+
|
| 87 |
+
trigger_cols = _detect_columns(df, "trigger_")
|
| 88 |
+
move_cols = _detect_columns(df, "move_")
|
| 89 |
+
|
| 90 |
+
# Khởi tạo wingman prompt-based
|
| 91 |
+
print("[BUILD_DATASET] Initializing ReplySuggestionService (Inference API)...")
|
| 92 |
+
reply_service = ReplySuggestionService()
|
| 93 |
+
|
| 94 |
+
out_rows = []
|
| 95 |
+
|
| 96 |
+
for idx, row in df.iterrows():
|
| 97 |
+
user_text = str(row.get("user_text", "") or "")
|
| 98 |
+
partner_text = str(row.get("partner_text", "") or "")
|
| 99 |
+
conversation = _build_conversation(user_text, partner_text)
|
| 100 |
+
|
| 101 |
+
if not conversation:
|
| 102 |
+
continue
|
| 103 |
+
|
| 104 |
+
trigger = _get_active_label(row, trigger_cols, "trigger_")
|
| 105 |
+
move = _get_active_label(row, move_cols, "move_")
|
| 106 |
+
|
| 107 |
+
try:
|
| 108 |
+
male_reply = reply_service.suggest_reply(
|
| 109 |
+
male=user_text,
|
| 110 |
+
female=partner_text,
|
| 111 |
+
tone=move,
|
| 112 |
+
intent=trigger,
|
| 113 |
+
)
|
| 114 |
+
except Exception as exc:
|
| 115 |
+
print(f"[BUILD_DATASET] Row {idx}: reply generation failed: {exc}")
|
| 116 |
+
continue
|
| 117 |
+
|
| 118 |
+
out_rows.append(
|
| 119 |
+
{
|
| 120 |
+
"conversation": conversation,
|
| 121 |
+
"user_text": user_text,
|
| 122 |
+
"partner_text": partner_text,
|
| 123 |
+
"trigger": trigger,
|
| 124 |
+
"move": move,
|
| 125 |
+
"male_reply": male_reply,
|
| 126 |
+
}
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
if (idx + 1) % 50 == 0:
|
| 130 |
+
print(f"[BUILD_DATASET] Processed {idx + 1} rows...")
|
| 131 |
+
|
| 132 |
+
out_df = pd.DataFrame(out_rows)
|
| 133 |
+
out_df.to_csv(args.output_path, index=False)
|
| 134 |
+
print(f"[BUILD_DATASET] Saved {len(out_df)} rows to {args.output_path}")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
if __name__ == "__main__":
|
| 138 |
+
main()
|
| 139 |
+
|
| 140 |
+
|
finetune_model.py
CHANGED
|
@@ -66,30 +66,56 @@ def build_instruction(conversation: str, trigger: str, move: str, persona: str)
|
|
| 66 |
|
| 67 |
|
| 68 |
def prepare_training_data(df, use_history=True, persona="default"):
|
| 69 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
training_data = []
|
| 71 |
conversation_history = []
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
trigger_cols = [col for col in df.columns if col.startswith("trigger_")]
|
| 74 |
move_cols = [col for col in df.columns if col.startswith("move_")]
|
| 75 |
-
|
| 76 |
-
for
|
| 77 |
-
user_text = str(row[
|
| 78 |
-
partner_text = str(row[
|
| 79 |
-
|
| 80 |
-
# Skip rows with invalid data
|
| 81 |
if not partner_text or partner_text.strip() == "_":
|
| 82 |
continue
|
| 83 |
-
|
| 84 |
-
# Get active triggers and moves
|
| 85 |
active_triggers = get_active_labels(row, trigger_cols)
|
| 86 |
active_moves = get_active_labels(row, move_cols)
|
| 87 |
-
|
| 88 |
-
# Format: Use only the first active trigger/move (highest priority)
|
| 89 |
trigger = active_triggers[0] if active_triggers[0] != "none" else "neutral"
|
| 90 |
move = active_moves[0] if active_moves[0] != "none" else "neutral"
|
| 91 |
-
|
| 92 |
-
# Build conversation context
|
| 93 |
if use_history and conversation_history:
|
| 94 |
history_str = "\n".join(conversation_history)
|
| 95 |
if user_text and user_text.strip() != "_":
|
|
@@ -102,28 +128,27 @@ def prepare_training_data(df, use_history=True, persona="default"):
|
|
| 102 |
conversation = f"Male: {user_text} ||| Female: {partner_text}"
|
| 103 |
else:
|
| 104 |
conversation = f"Female: {partner_text}"
|
| 105 |
-
|
| 106 |
prompt = build_instruction(conversation, trigger, move, persona)
|
| 107 |
-
|
| 108 |
response = partner_text.strip()
|
| 109 |
-
|
| 110 |
-
training_data.append(
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
|
|
|
| 117 |
if user_text and user_text.strip() != "_":
|
| 118 |
conversation_history.append(f"Male: {user_text}")
|
| 119 |
if partner_text and partner_text.strip() != "_":
|
| 120 |
conversation_history.append(f"Female: {partner_text}")
|
| 121 |
-
|
| 122 |
-
# Limit history length
|
| 123 |
max_history = 4
|
| 124 |
if len(conversation_history) > max_history:
|
| 125 |
conversation_history = conversation_history[-max_history:]
|
| 126 |
-
|
| 127 |
return training_data
|
| 128 |
|
| 129 |
|
|
|
|
| 66 |
|
| 67 |
|
| 68 |
def prepare_training_data(df, use_history=True, persona="default"):
|
| 69 |
+
"""
|
| 70 |
+
Prepare data for fine-tuning.
|
| 71 |
+
|
| 72 |
+
Nếu dataset đã có cột `male_reply` (build bởi build_reply_dataset.py) thì dùng:
|
| 73 |
+
conversation, trigger, move, male_reply
|
| 74 |
+
Làm ground-truth chuẩn cho reply từ phía Nam.
|
| 75 |
+
Nếu không, fallback về logic cũ dựa trên user_text / partner_text (ít lý tưởng hơn).
|
| 76 |
+
"""
|
| 77 |
training_data = []
|
| 78 |
conversation_history = []
|
| 79 |
+
|
| 80 |
+
has_clean_reply = {"conversation", "trigger", "move", "male_reply"}.issubset(set(df.columns))
|
| 81 |
+
|
| 82 |
+
if has_clean_reply:
|
| 83 |
+
for _, row in df.iterrows():
|
| 84 |
+
conversation = str(row.get("conversation", "") or "")
|
| 85 |
+
trigger = str(row.get("trigger", "") or "neutral")
|
| 86 |
+
move = str(row.get("move", "") or "neutral")
|
| 87 |
+
reply = str(row.get("male_reply", "") or "").strip()
|
| 88 |
+
|
| 89 |
+
if not conversation or not reply:
|
| 90 |
+
continue
|
| 91 |
+
|
| 92 |
+
prompt = build_instruction(conversation, trigger, move, persona)
|
| 93 |
+
training_data.append(
|
| 94 |
+
{
|
| 95 |
+
"instruction": prompt,
|
| 96 |
+
"input": "",
|
| 97 |
+
"output": reply,
|
| 98 |
+
}
|
| 99 |
+
)
|
| 100 |
+
return training_data
|
| 101 |
+
|
| 102 |
+
# Fallback: dùng dữ liệu gốc (kém lý tưởng hơn)
|
| 103 |
trigger_cols = [col for col in df.columns if col.startswith("trigger_")]
|
| 104 |
move_cols = [col for col in df.columns if col.startswith("move_")]
|
| 105 |
+
|
| 106 |
+
for _, row in df.iterrows():
|
| 107 |
+
user_text = str(row["user_text"]) if pd.notna(row.get("user_text")) else ""
|
| 108 |
+
partner_text = str(row["partner_text"]) if pd.notna(row.get("partner_text")) else ""
|
| 109 |
+
|
|
|
|
| 110 |
if not partner_text or partner_text.strip() == "_":
|
| 111 |
continue
|
| 112 |
+
|
|
|
|
| 113 |
active_triggers = get_active_labels(row, trigger_cols)
|
| 114 |
active_moves = get_active_labels(row, move_cols)
|
| 115 |
+
|
|
|
|
| 116 |
trigger = active_triggers[0] if active_triggers[0] != "none" else "neutral"
|
| 117 |
move = active_moves[0] if active_moves[0] != "none" else "neutral"
|
| 118 |
+
|
|
|
|
| 119 |
if use_history and conversation_history:
|
| 120 |
history_str = "\n".join(conversation_history)
|
| 121 |
if user_text and user_text.strip() != "_":
|
|
|
|
| 128 |
conversation = f"Male: {user_text} ||| Female: {partner_text}"
|
| 129 |
else:
|
| 130 |
conversation = f"Female: {partner_text}"
|
| 131 |
+
|
| 132 |
prompt = build_instruction(conversation, trigger, move, persona)
|
|
|
|
| 133 |
response = partner_text.strip()
|
| 134 |
+
|
| 135 |
+
training_data.append(
|
| 136 |
+
{
|
| 137 |
+
"instruction": prompt,
|
| 138 |
+
"input": "",
|
| 139 |
+
"output": response,
|
| 140 |
+
}
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
if user_text and user_text.strip() != "_":
|
| 144 |
conversation_history.append(f"Male: {user_text}")
|
| 145 |
if partner_text and partner_text.strip() != "_":
|
| 146 |
conversation_history.append(f"Female: {partner_text}")
|
| 147 |
+
|
|
|
|
| 148 |
max_history = 4
|
| 149 |
if len(conversation_history) > max_history:
|
| 150 |
conversation_history = conversation_history[-max_history:]
|
| 151 |
+
|
| 152 |
return training_data
|
| 153 |
|
| 154 |
|