import json
from pathlib import Path
from .text_quality import clean_answer_text, clean_context_text, clean_training_text
# File suffixes read verbatim as plain-text corpus sources.
TEXT_EXTENSIONS = {".txt", ".md", ".text"}
# File suffixes parsed as structured record sources (JSON / JSON Lines).
STRUCTURED_EXTENSIONS = {".jsonl", ".json"}
def _default_record_weight(record_type: str) -> int:
if record_type == "dialogue_turn":
return 2
if record_type == "instruction_answer":
return 2
if record_type == "preference_chosen":
return 3
if record_type == "preference_rejected":
return 0
return 1
def _record_repeat_count(record: object) -> int:
if not isinstance(record, dict):
return 1
if bool(record.get("drop")):
return 0
raw_weight = record.get("weight")
if raw_weight is not None:
try:
numeric = int(round(float(raw_weight)))
except (TypeError, ValueError):
numeric = 1
return max(0, min(8, numeric))
return _default_record_weight(str(record.get("record_type", "")))
def _coerce_text_record(record: object) -> str:
    """Normalize one raw corpus record into a cleaned training string.

    Accepts a bare string, or a dict with a "text" or "content" field, or a
    dict with both "context" and "answer". Anything else — or any record
    whose usable text is empty after cleaning — yields "".
    """
    if isinstance(record, str):
        return clean_training_text(record.strip())
    if not isinstance(record, dict):
        return ""
    for field in ("text", "content"):
        if field in record:
            return clean_training_text(str(record[field]).strip())
    if "context" in record and "answer" in record:
        context = clean_context_text(str(record["context"]).strip())
        answer = clean_answer_text(str(record["answer"]).strip())
        if context and answer:
            return f"<reason> {context} <answer> {answer}"
    return ""
def _coerce_prompt_record(record: object) -> dict[str, object] | None:
    """Normalize one raw prompt record into a dict with "prompt" and "tags".

    Accepts a bare string or a dict (using "prompt", falling back to
    "context"). Returns None when no usable prompt text remains after
    cleaning. Dict records keep their other keys.
    """
    if isinstance(record, str):
        # Fix: bare-string prompts previously skipped clean_context_text,
        # unlike the dict path below — clean both paths consistently.
        prompt = clean_context_text(record.strip())
        return {"prompt": prompt, "tags": []} if prompt else None
    if isinstance(record, dict):
        raw_prompt = record.get("prompt", record.get("context", ""))
        prompt = clean_context_text(str(raw_prompt).strip())
        if not prompt:
            return None
        raw_tags = record.get("tags", [])
        # Tags must be a list; anything else is discarded rather than guessed at.
        tags = [str(tag) for tag in raw_tags] if isinstance(raw_tags, list) else []
        normalized = dict(record)
        normalized["prompt"] = prompt
        normalized["tags"] = tags
        return normalized
    return None
def _join_weighted_texts(items: list) -> str:
    """Clean each item and repeat it by its record weight, newline-joined."""
    parts: list[str] = []
    for item in items:
        text = _coerce_text_record(item)
        if text:
            parts.extend([text] * _record_repeat_count(item))
    return "\n".join(parts)


def load_text_corpus(source: str | Path) -> str:
    """Load a training corpus from a file or directory into one string.

    Plain-text suffixes are read verbatim. ".jsonl" is parsed one record
    per non-blank line; ".json" may be a list of records, a dict with a
    "texts" or "records" list, or a single record. Directories are walked
    recursively in sorted order, joining each supported file's corpus.

    Raises:
        ValueError: if the path's suffix is not a supported corpus format.
    """
    path = Path(source)
    if path.is_dir():
        supported = TEXT_EXTENSIONS | STRUCTURED_EXTENSIONS
        parts = [
            load_text_corpus(child)
            for child in sorted(path.rglob("*"))
            if child.is_file() and child.suffix.lower() in supported
        ]
        return "\n".join(part for part in parts if part.strip())
    suffix = path.suffix.lower()
    if suffix in TEXT_EXTENSIONS:
        return path.read_text(encoding="utf-8")
    if suffix == ".jsonl":
        records = [
            json.loads(line)
            for line in path.read_text(encoding="utf-8").splitlines()
            if line.strip()
        ]
        return _join_weighted_texts(records)
    if suffix == ".json":
        payload = json.loads(path.read_text(encoding="utf-8"))
        if isinstance(payload, list):
            return _join_weighted_texts(payload)
        if isinstance(payload, dict):
            # "texts" wins over "records" when both are present,
            # matching the original precedence.
            for key in ("texts", "records"):
                if key in payload and isinstance(payload[key], list):
                    return _join_weighted_texts(payload[key])
            # A dict that is itself a single record.
            text = _coerce_text_record(payload)
            if text:
                return "\n".join([text] * _record_repeat_count(payload))
    raise ValueError(f"Unsupported corpus source: {path}")
def _collect_prompt_records(items) -> list[dict[str, object]]:
    """Normalize each item, keeping only records with a usable prompt."""
    coerced = (_coerce_prompt_record(item) for item in items)
    return [record for record in coerced if record is not None]


def load_prompt_suite(source: str | Path) -> list[dict[str, object]]:
    """Load a list of normalized prompt records from a file.

    Plain-text suffixes yield one prompt per non-blank line. ".jsonl" is
    parsed one record per non-blank line; ".json" may be a list of records,
    a dict with a "prompts" list, or a single record dict.

    Raises:
        ValueError: if the path's suffix is not a supported suite format.
    """
    path = Path(source)
    suffix = path.suffix.lower()
    if suffix in TEXT_EXTENSIONS:
        # Blank lines coerce to None and are dropped by the collector.
        return _collect_prompt_records(path.read_text(encoding="utf-8").splitlines())
    if suffix == ".jsonl":
        return _collect_prompt_records(
            json.loads(line)
            for line in path.read_text(encoding="utf-8").splitlines()
            if line.strip()
        )
    if suffix == ".json":
        payload = json.loads(path.read_text(encoding="utf-8"))
        if isinstance(payload, list):
            return _collect_prompt_records(payload)
        if isinstance(payload, dict):
            raw_prompts = payload.get("prompts")
            if isinstance(raw_prompts, list):
                return _collect_prompt_records(raw_prompts)
            # A dict that is itself a single prompt record.
            record = _coerce_prompt_record(payload)
            if record is not None:
                return [record]
    raise ValueError(f"Unsupported prompt suite: {path}")