Paul committed on
Commit
1a6e95f
·
1 Parent(s): bbce197
Files changed (2) hide show
  1. build_reply_dataset.py +140 -0
  2. finetune_model.py +52 -27
build_reply_dataset.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility script to build a clean reply-generation dataset for Models 1 & 2.
3
+
4
+ Goal:
5
+ - Tạo ra dataset mà output luôn là **câu trả lời từ phía Nam**,
6
+ đã được chuẩn hoá theo prompt wingman (anh/em, 1 câu, ≤25 từ).
7
+
8
+ Ý tưởng:
9
+ - Đọc `new_data_selected.csv` gốc (user_text, partner_text, trigger_*, move_*)
10
+ - Tính trigger/move chính (như khi fine-tune trigger detector)
11
+ - Gom hội thoại: `Male: user_text ||| Female: partner_text`
12
+ - Gọi wingman prompt-based (`ReplySuggestionService`) để sinh `male_reply`
13
+ - Ghi ra CSV mới: `conversation,trigger,move,male_reply`
14
+
15
+ Usage (local hoặc trong Spaces terminal):
16
+ python build_reply_dataset.py \
17
+ --data_path new_data_selected.csv \
18
+ --output_path reply_training_data.csv
19
+
20
+ Lưu ý:
21
+ - Script dùng Hugging Face Inference API, cần HF_TOKEN có quyền call model (như trong reply_service.py).
22
+ """
23
+
24
+ import argparse
25
+ import os
26
+ from typing import List
27
+
28
+ import pandas as pd
29
+
30
+ from reply_service import ReplySuggestionService
31
+
32
+
33
def _detect_columns(df: pd.DataFrame, prefix: str) -> List[str]:
    """Return the names of all columns in *df* that start with *prefix*.

    Raises:
        ValueError: when no column name carries the prefix.
    """
    matching = list(filter(lambda name: name.startswith(prefix), df.columns))
    if not matching:
        raise ValueError(f"No columns found with prefix '{prefix}'")
    return matching
38
+
39
+
40
def _get_active_label(row: pd.Series, cols: List[str], prefix: str) -> str:
    """Return the first label whose flag equals 1; fall back to 'neutral'."""
    active = next(
        (name for name in cols if float(row.get(name, 0)) == 1.0),
        None,
    )
    if active is None:
        return "neutral"
    return active.replace(prefix, "")
46
+
47
+
48
def _build_conversation(user_text: str, partner_text: str) -> str:
    """Format one exchange as 'Male: ... ||| Female: ...'.

    Returns an empty string when both sides are blank, and a female-only
    line when just the male side is blank.
    """
    male = "" if user_text is None else user_text.strip()
    female = "" if partner_text is None else partner_text.strip()
    if male:
        return f"Male: {male} ||| Female: {female}"
    if female:
        return f"Female: {female}"
    return ""
56
+
57
+
58
def main():
    """Build the reply dataset.

    Reads the source CSV, derives the primary trigger/move per row, asks the
    prompt-based wingman service for a male reply, and writes the collected
    rows to the output CSV. Rows with an empty conversation or a failed API
    call are skipped (best-effort synthesis).
    """
    parser = argparse.ArgumentParser(description="Build reply-generation dataset from new_data_selected.csv")
    parser.add_argument(
        "--data_path",
        type=str,
        default="new_data_selected.csv",
        help="Đường dẫn đến CSV gốc (new_data_selected.csv)",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="reply_training_data.csv",
        help="Đường dẫn file CSV output chứa male_reply",
    )
    parser.add_argument(
        "--max_rows",
        type=int,
        default=-1,
        help="Giới hạn số dòng xử lý (debug). -1 = dùng toàn bộ",
    )
    args = parser.parse_args()

    print(f"[BUILD_DATASET] Loading dataset from {args.data_path}")
    df = pd.read_csv(args.data_path)

    if args.max_rows > 0:
        df = df.head(args.max_rows)
        print(f"[BUILD_DATASET] Using first {len(df)} rows for reply synthesis")

    trigger_cols = _detect_columns(df, "trigger_")
    move_cols = _detect_columns(df, "move_")

    # Prompt-based wingman (remote Inference API; requires HF_TOKEN).
    print("[BUILD_DATASET] Initializing ReplySuggestionService (Inference API)...")
    reply_service = ReplySuggestionService()

    out_rows = []

    # FIX: use an explicit 1-based counter for progress reporting. The old
    # `(idx + 1) % 50` relied on the DataFrame having a default RangeIndex;
    # with a non-sequential or non-integer index it misreported progress or
    # raised a TypeError.
    for processed, (idx, row) in enumerate(df.iterrows(), start=1):
        user_text = str(row.get("user_text", "") or "")
        partner_text = str(row.get("partner_text", "") or "")
        conversation = _build_conversation(user_text, partner_text)

        if not conversation:
            continue

        trigger = _get_active_label(row, trigger_cols, "trigger_")
        move = _get_active_label(row, move_cols, "move_")

        try:
            male_reply = reply_service.suggest_reply(
                male=user_text,
                female=partner_text,
                tone=move,
                intent=trigger,
            )
        except Exception as exc:  # best-effort: skip rows the API rejects
            print(f"[BUILD_DATASET] Row {idx}: reply generation failed: {exc}")
            continue

        out_rows.append(
            {
                "conversation": conversation,
                "user_text": user_text,
                "partner_text": partner_text,
                "trigger": trigger,
                "move": move,
                "male_reply": male_reply,
            }
        )

        if processed % 50 == 0:
            print(f"[BUILD_DATASET] Processed {processed} rows...")

    out_df = pd.DataFrame(out_rows)
    out_df.to_csv(args.output_path, index=False)
    print(f"[BUILD_DATASET] Saved {len(out_df)} rows to {args.output_path}")
135
+
136
+
137
# Entry point when executed as a standalone script.
if __name__ == "__main__":
    main()
139
+
140
+
finetune_model.py CHANGED
@@ -66,30 +66,56 @@ def build_instruction(conversation: str, trigger: str, move: str, persona: str)
66
 
67
 
68
  def prepare_training_data(df, use_history=True, persona="default"):
69
- """Prepare data for fine-tuning"""
 
 
 
 
 
 
 
70
  training_data = []
71
  conversation_history = []
72
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  trigger_cols = [col for col in df.columns if col.startswith("trigger_")]
74
  move_cols = [col for col in df.columns if col.startswith("move_")]
75
-
76
- for idx, row in df.iterrows():
77
- user_text = str(row['user_text']) if pd.notna(row['user_text']) else ""
78
- partner_text = str(row['partner_text']) if pd.notna(row['partner_text']) else ""
79
-
80
- # Skip rows with invalid data
81
  if not partner_text or partner_text.strip() == "_":
82
  continue
83
-
84
- # Get active triggers and moves
85
  active_triggers = get_active_labels(row, trigger_cols)
86
  active_moves = get_active_labels(row, move_cols)
87
-
88
- # Format: Use only the first active trigger/move (highest priority)
89
  trigger = active_triggers[0] if active_triggers[0] != "none" else "neutral"
90
  move = active_moves[0] if active_moves[0] != "none" else "neutral"
91
-
92
- # Build conversation context
93
  if use_history and conversation_history:
94
  history_str = "\n".join(conversation_history)
95
  if user_text and user_text.strip() != "_":
@@ -102,28 +128,27 @@ def prepare_training_data(df, use_history=True, persona="default"):
102
  conversation = f"Male: {user_text} ||| Female: {partner_text}"
103
  else:
104
  conversation = f"Female: {partner_text}"
105
-
106
  prompt = build_instruction(conversation, trigger, move, persona)
107
-
108
  response = partner_text.strip()
109
-
110
- training_data.append({
111
- "instruction": prompt,
112
- "input": "",
113
- "output": response
114
- })
115
-
116
- # Update conversation history
 
117
  if user_text and user_text.strip() != "_":
118
  conversation_history.append(f"Male: {user_text}")
119
  if partner_text and partner_text.strip() != "_":
120
  conversation_history.append(f"Female: {partner_text}")
121
-
122
- # Limit history length
123
  max_history = 4
124
  if len(conversation_history) > max_history:
125
  conversation_history = conversation_history[-max_history:]
126
-
127
  return training_data
128
 
129
 
 
66
 
67
 
68
  def prepare_training_data(df, use_history=True, persona="default"):
69
+ """
70
+ Prepare data for fine-tuning.
71
+
72
+ Nếu dataset đã có cột `male_reply` (build bởi build_reply_dataset.py) thì dùng:
73
+ conversation, trigger, move, male_reply
74
+ Làm ground-truth chuẩn cho reply từ phía Nam.
75
+ Nếu không, fallback về logic cũ dựa trên user_text / partner_text (ít lý tưởng hơn).
76
+ """
77
  training_data = []
78
  conversation_history = []
79
+
80
+ has_clean_reply = {"conversation", "trigger", "move", "male_reply"}.issubset(set(df.columns))
81
+
82
+ if has_clean_reply:
83
+ for _, row in df.iterrows():
84
+ conversation = str(row.get("conversation", "") or "")
85
+ trigger = str(row.get("trigger", "") or "neutral")
86
+ move = str(row.get("move", "") or "neutral")
87
+ reply = str(row.get("male_reply", "") or "").strip()
88
+
89
+ if not conversation or not reply:
90
+ continue
91
+
92
+ prompt = build_instruction(conversation, trigger, move, persona)
93
+ training_data.append(
94
+ {
95
+ "instruction": prompt,
96
+ "input": "",
97
+ "output": reply,
98
+ }
99
+ )
100
+ return training_data
101
+
102
+ # Fallback: dùng dữ liệu gốc (kém lý tưởng hơn)
103
  trigger_cols = [col for col in df.columns if col.startswith("trigger_")]
104
  move_cols = [col for col in df.columns if col.startswith("move_")]
105
+
106
+ for _, row in df.iterrows():
107
+ user_text = str(row["user_text"]) if pd.notna(row.get("user_text")) else ""
108
+ partner_text = str(row["partner_text"]) if pd.notna(row.get("partner_text")) else ""
109
+
 
110
  if not partner_text or partner_text.strip() == "_":
111
  continue
112
+
 
113
  active_triggers = get_active_labels(row, trigger_cols)
114
  active_moves = get_active_labels(row, move_cols)
115
+
 
116
  trigger = active_triggers[0] if active_triggers[0] != "none" else "neutral"
117
  move = active_moves[0] if active_moves[0] != "none" else "neutral"
118
+
 
119
  if use_history and conversation_history:
120
  history_str = "\n".join(conversation_history)
121
  if user_text and user_text.strip() != "_":
 
128
  conversation = f"Male: {user_text} ||| Female: {partner_text}"
129
  else:
130
  conversation = f"Female: {partner_text}"
131
+
132
  prompt = build_instruction(conversation, trigger, move, persona)
 
133
  response = partner_text.strip()
134
+
135
+ training_data.append(
136
+ {
137
+ "instruction": prompt,
138
+ "input": "",
139
+ "output": response,
140
+ }
141
+ )
142
+
143
  if user_text and user_text.strip() != "_":
144
  conversation_history.append(f"Male: {user_text}")
145
  if partner_text and partner_text.strip() != "_":
146
  conversation_history.append(f"Female: {partner_text}")
147
+
 
148
  max_history = 4
149
  if len(conversation_history) > max_history:
150
  conversation_history = conversation_history[-max_history:]
151
+
152
  return training_data
153
 
154