Spaces:
Running
Running
File size: 5,112 Bytes
1ea875f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 |
# 文件路径: evaluation/utils.py
"""
评估模块公共工具函数和常量
将重复的逻辑抽取到这里,保持代码 DRY (Don't Repeat Yourself)
"""
from typing import List, Optional, Sequence
# ============================================================================
# Chit-chat / invalid query detection
# ============================================================================
# Lower-case small-talk phrases and trivial inputs, used by is_chatty_query()
# to flag queries that are not worth evaluating.
CHATTY_PATTERNS: List[str] = [
    # Chinese small talk
    "你好", "您好", "嗨", "在吗", "在不在",
    "谢谢", "多谢", "再见", "拜拜",
    "什么是", "你是谁", "你叫什么", "帮帮我", "教教我",
    # English small talk
    "hello", "hi", "hey",
    "thanks", "thank you",
    "bye", "goodbye",
    "what is", "who are you", "help me", "can you",
    # Single words / very short inputs
    "test", "测试", "ok", "yes", "no",
]
# Substrings whose presence suggests a text contains source code; checked by
# has_code_indicators(). Keywords include a trailing space, so e.g. "define"
# does not match "def ".
CODE_INDICATORS: List[str] = [
    "def ", "class ", "import ", "from ",       # Python
    "function ", "const ", "let ", "var ",      # JavaScript / TypeScript
    "public ", "private ", "void ",             # Java / C#
    "func ", "package ",                        # Go
    "```",                                      # Markdown fenced code block
]
# Characters ignored at the end of a query when comparing it for an exact
# pattern match, so "thank you!" or "你好!" still count. Covers ASCII
# whitespace/punctuation plus common full-width (CJK) punctuation:
# 。 ! ? , ; : ~ … 、
_TRAILING_NOISE = (
    " \t\r\n!?.,;:~"
    "\u3002\uff01\uff1f\uff0c\uff1b\uff1a\uff5e\u2026\u3001"
)


def is_chatty_query(
    query: str,
    min_length: int = 5,
    patterns: Optional[Sequence[str]] = None,
) -> bool:
    """Return True if *query* looks like small talk or is too short to be useful.

    A query is treated as chatty/invalid when any of the following holds:

    * it is empty or ``None``;
    * after lower-casing and stripping surrounding whitespace it is shorter
      than *min_length* characters;
    * it equals one of *patterns*, ignoring trailing punctuation such as
      ``!``, ``?`` or ``。``;
    * it starts with one of *patterns* followed by a single space
      (e.g. ``"hello there"``).

    Args:
        query: The user query.
        min_length: Minimum number of characters (after stripping) for a
            query to be considered meaningful.
        patterns: Lower-case small-talk patterns to match against. Defaults
            to the module-level ``CHATTY_PATTERNS``.

    Returns:
        True if the query is chatty or invalid, False otherwise.
    """
    if not query:
        return True

    query_lower = query.lower().strip()

    # Too short to carry a real question.
    if len(query_lower) < min_length:
        return True

    if patterns is None:
        patterns = CHATTY_PATTERNS

    # Trailing punctuation is dropped only for the exact-match comparison; the
    # prefix check uses the unmodified text so that e.g. "hello.py crashes"
    # is not mistaken for a greeting.
    core = query_lower.rstrip(_TRAILING_NOISE)
    return any(
        core == pattern or query_lower.startswith(pattern + " ")
        for pattern in patterns
    )
def has_code_indicators(text: str) -> bool:
    """Report whether *text* contains any entry of ``CODE_INDICATORS``.

    The check is a plain, case-sensitive substring search.

    Args:
        text: Text to inspect.

    Returns:
        True if at least one indicator occurs in *text*; False for empty or
        ``None`` input, or when nothing matches.
    """
    if text:
        return any(marker in text for marker in CODE_INDICATORS)
    return False
# ============================================================================
# File helpers
# ============================================================================
def append_jsonl(filepath: str, data: dict) -> None:
    """Append *data* as one JSON line to the JSONL file at *filepath*.

    The file is opened in append mode (so it is created if missing) with
    UTF-8 encoding; non-ASCII characters are written as-is rather than
    escaped.

    Args:
        filepath: Path of the JSONL file.
        data: Dictionary to serialize and append.
    """
    import json

    with open(filepath, mode='a', encoding='utf-8') as handle:
        record = json.dumps(data, ensure_ascii=False)
        handle.write(record + '\n')
def read_jsonl(filepath: str) -> list:
    """Read a JSONL file and return its decoded records.

    This is a best-effort reader:

    * a missing file yields an empty list;
    * blank lines are ignored;
    * lines that are not valid JSON (for example a partially written line)
      are skipped and reported via a logged warning, instead of aborting
      the whole read.

    Args:
        filepath: Path of the JSONL file.

    Returns:
        The decoded objects in file order (``[]`` if the file does not exist).
    """
    import json
    import logging

    logger = logging.getLogger(__name__)

    # EAFP: avoids the exists()/open() race of checking first.
    try:
        f = open(filepath, 'r', encoding='utf-8')
    except FileNotFoundError:
        return []

    results = []
    with f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                results.append(json.loads(line))
            except json.JSONDecodeError as exc:
                logger.warning(
                    "Skipping malformed JSONL line %d in %s: %s",
                    lineno, filepath, exc,
                )
    return results
def safe_truncate(text: str, max_length: int, suffix: str = "\n... [truncated]") -> str:
    """Truncate *text* so that the result, suffix included, fits in *max_length*.

    Args:
        text: Original text. Empty strings and ``None`` are returned unchanged.
        max_length: Maximum length of the returned string. Negative values
            are treated as 0.
        suffix: Marker appended when the text is cut. If *max_length* is too
            small to hold the marker plus at least one character of *text*,
            the text is hard-cut to *max_length* without a marker.

    Returns:
        *text* unchanged if it already fits, otherwise a truncated string of
        at most *max_length* characters.
    """
    if not text or len(text) <= max_length:
        return text

    limit = max(max_length, 0)
    # Reserve room for the suffix so the result never exceeds the budget.
    keep = limit - len(suffix)
    if keep <= 0:
        # No space for the marker: fall back to a plain hard cut.
        return text[:limit]
    return text[:keep] + suffix
def smart_truncate(text: str, max_length: int, keep_ratio: float = 0.7) -> str:
    """Truncate *text* by keeping its beginning and end, dropping the middle.

    Useful for code context, where both the head and the tail tend to matter.
    The returned string, including the omission separator, never exceeds
    *max_length* characters.

    Args:
        text: Original text. Empty strings and ``None`` are returned unchanged.
        max_length: Maximum length of the returned string. If it cannot even
            hold the separator, the text is hard-cut to *max_length* (negative
            values are treated as 0).
        keep_ratio: Share of the remaining budget given to the head (default
            0.7, i.e. 70% head / 30% tail). Values are clamped to [0, 1].

    Returns:
        *text* unchanged if it already fits; otherwise
        ``head + separator + tail``.
    """
    if not text or len(text) <= max_length:
        return text

    separator = "\n\n... [中间内容已省略] ...\n\n"
    available = max_length - len(separator)
    if available <= 0:
        return text[:max(max_length, 0)]

    # Clamp so neither head_len nor tail_len can become negative.
    ratio = min(max(keep_ratio, 0.0), 1.0)
    head_len = int(available * ratio)
    tail_len = available - head_len

    # text[-0:] is the *whole* string, so an empty tail must be special-cased.
    tail = text[-tail_len:] if tail_len > 0 else ""
    return text[:head_len] + separator + tail
# ============================================================================
# SFT data length configuration
# ============================================================================
class SFTLengthConfig:
    """Character budgets for SFT training data.

    Token figures in the comments are rough estimates based on
    ``CHARS_PER_TOKEN`` (a conservative guess for mixed Chinese/English text).
    """

    # Retrieved code context: at most 2,500 characters (~800 tokens).
    MAX_CONTEXT_CHARS: int = 2_500
    # Model-generated answer: at most 3,000 characters (~1,000 tokens).
    MAX_ANSWER_CHARS: int = 3_000
    # User query: at most 500 characters.
    MAX_QUERY_CHARS: int = 500
    # Overall cap on total characters (~2,000 tokens).
    MAX_TOTAL_CHARS: int = 6_000
    # Average number of characters per token used for the estimates above.
    CHARS_PER_TOKEN: int = 3
|