Spaces:
Running
Running
File size: 44,569 Bytes
1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a d8c5622 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a 0701598 1a93d7a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 |
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 属性选择器脚本
# 此脚本用于分析用户配置文件并选择最适合的属性
import json
import os
import sys
import logging
from typing import Dict, List, Any, Optional, Tuple
import argparse
from pathlib import Path
import random
import numpy as np
import pickle
from tqdm import tqdm
import time
# Module-level slot holding the most recent result of get_selected_attributes().
# It is written on every call but never read back (the cache lookup there is
# disabled), so each call re-runs the selection.
ATTRIBUTE_SELECTION_CACHE = None
# 导入项目配置
from config import client, GPT_MODEL, parse_json_response
# 定义get_completion函数
def get_completion(messages, model=GPT_MODEL, temperature=0.7):
    """
    Send a chat conversation to the OpenAI chat-completions endpoint.

    Args:
        messages: List of ``{"role": ..., "content": ...}`` chat messages.
        model: Model name; defaults to GPT_MODEL from config.py.
        temperature: Sampling temperature.

    Returns:
        The text content of the first choice, or None if the request or the
        response unpacking raises (the error is logged).
    """
    try:
        completion = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
    except Exception as exc:
        logger.error(f"Error calling OpenAI API: {exc}")
        return None
# 导入based_data模块中的函数
from based_data import (
generate_age_info,
generate_gender,
generate_career_info,
generate_location,
generate_personal_values,
generate_life_attitude,
generate_personal_story,
generate_interests_and_hobbies
)
# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Attribute catalogue (nested JSON of attribute categories). The default is the
# original developer path; set DEEPPERSONA_ATTRIBUTES_PATH to point elsewhere.
ATTRIBUTES_PATH = os.environ.get(
    "DEEPPERSONA_ATTRIBUTES_PATH",
    "/home/zhou/deeppersona/generate_user_profile_test/data/large_attributes.json",
)
# Pickled attribute-embedding database; override with DEEPPERSONA_EMBEDDINGS_PATH.
EMBEDDINGS_PATH = os.environ.get(
    "DEEPPERSONA_EMBEDDINGS_PATH",
    "/home/zhou/deeppersona/generate_user_profile_test/data/attribute_embeddings.pkl",
)
# Default chat model, taken from config.py
DEFAULT_MODEL = GPT_MODEL
# Vector-search parameters.
# NOTE(review): none of these four constants is referenced in this file —
# AttributeSelector._find_interesting_neighbors hard-codes a 5:3:2 split.
# They may be imported elsewhere; confirm before removing.
NEAR_NEIGHBOR_COUNT = 7  # number of near neighbours
MID_NEIGHBOR_COUNT = 2  # number of mid-distance neighbours
FAR_NEIGHBOR_COUNT = 1  # number of far neighbours
DIVERSITY_THRESHOLD = 0.7  # diversity threshold (cosine similarity)
class AttributeSelector:
    """
    Attribute selector.

    Analyses a user profile and selects attribute paths from the attribute
    catalogue at ATTRIBUTES_PATH. A GPT call (via ``get_completion``) decides
    whether the "Career and Work Identity" category is relevant; similarity
    between a profile embedding and precomputed attribute embeddings then
    picks individual attribute paths.
    """

    def __init__(self, model: str = DEFAULT_MODEL, user_profile: Dict = None):
        """
        Initialise the selector.

        Args:
            model: Chat model used for GPT calls.
            user_profile: Optional user profile dict. When omitted,
                ``process_profile`` generates one.

        Raises:
            Exception: re-raised by ``_load_json`` if the attribute file cannot
                be read or parsed.
            ValueError: if the attribute file is not a JSON object.
        """
        self.model = model
        # The OpenAI client itself is configured in config.py.
        self.attributes = self._load_json(ATTRIBUTES_PATH)
        self.user_profile = user_profile
        # Normalises self.attributes to {"paths": [dotted leaf paths]}.
        self._validate_data()

        # Optional embedding database; None when missing or malformed.
        self.embeddings_data = self._load_embeddings()
        self.path_to_embedding = {}
        self.paths = []
        self.embeddings = []
        if self.embeddings_data:
            self.paths = self.embeddings_data.get('paths', [])
            self.embeddings = self.embeddings_data.get('embeddings', [])
            # zip() stops at the shorter sequence, so paths without a matching
            # embedding row are skipped.
            for path, embedding in zip(self.paths, self.embeddings):
                self.path_to_embedding[path] = embedding
            logger.info(f"已加载 {len(self.paths)} 条属性路径和对应的向量嵌入")
        # After _validate_data the dict only holds the "paths" key, so
        # len(self.attributes.keys()) was always 1; count the real categories.
        logger.info(f"已加载属性,包含 {len(self._get_attribute_categories())} 个顶级类别")

    def _load_json(self, file_path: str) -> Dict:
        """
        Read and parse a UTF-8 JSON file.

        Raises:
            Exception: any read or parse error, after logging it.
        """
        try:
            return json.loads(Path(file_path).read_text(encoding='utf-8'))
        except Exception as e:
            logger.error(f"从 {file_path} 加载JSON时出错: {e}")
            raise

    def _load_embeddings(self) -> Optional[Dict]:
        """
        Load the pickled attribute-embedding database from EMBEDDINGS_PATH.

        The pickle must be a dict with a path list under ``attribute_paths`` or
        ``paths`` and a numpy array under ``embeddings``.

        Returns:
            ``{'paths': list, 'embeddings': np.ndarray}``, or None when the file
            is missing, malformed, or cannot be read (problems are logged).
        """
        try:
            if not os.path.exists(EMBEDDINGS_PATH):
                logger.warning(f"嵌入向量文件 {EMBEDDINGS_PATH} 不存在")
                return None
            # SECURITY NOTE: pickle.load can execute arbitrary code; only point
            # EMBEDDINGS_PATH at a trusted file.
            with open(EMBEDDINGS_PATH, 'rb') as f:
                embeddings_data = pickle.load(f)
            # Check the container type *before* using dict operations on it.
            if not isinstance(embeddings_data, dict):
                logger.warning("嵌入向量数据格式无效")
                return None
            paths_key = 'attribute_paths' if 'attribute_paths' in embeddings_data else 'paths'
            paths = embeddings_data.get(paths_key)
            embeddings = embeddings_data.get('embeddings')
            # len() instead of truthiness: `not paths` raises for numpy arrays.
            if paths is None or len(paths) == 0 or not isinstance(embeddings, np.ndarray):
                logger.warning("嵌入向量数据格式无效")
                return None
            paths = list(paths)
            logger.info(f"从 {EMBEDDINGS_PATH} 加载了 {len(paths)} 条属性嵌入向量")
            return {'paths': paths, 'embeddings': embeddings}
        except Exception as e:
            logger.error(f"加载嵌入向量时出错: {e}")
            return None

    def _validate_data(self):
        """
        Validate the loaded attributes and normalise them to path format.

        A nested category dict (such as large_attributes.json) is flattened
        into ``{"paths": [...]}``. A dict that already has a "paths" key is
        left unchanged.

        Raises:
            ValueError: if the attributes are not a dict.
        """
        if not isinstance(self.attributes, dict):
            raise ValueError("无效的属性格式:不是字典")
        if "paths" not in self.attributes:
            logger.info("正在将属性从嵌套字典格式转换为路径格式")
            self.attributes = {"paths": self._flatten_attributes(self.attributes)}

    def _create_profile_embedding(self, profile: Dict) -> Optional[np.ndarray]:
        """
        Embed a profile's text summary with the OpenAI embeddings API.

        Args:
            profile: User profile dict.

        Returns:
            The embedding as a numpy array, or None if the API call fails.
        """
        try:
            profile_summary = self._extract_profile_summary(profile)
            response = client.embeddings.create(
                model="text-embedding-ada-002",
                input=profile_summary
            )
            embedding = np.array(response.data[0].embedding)
            logger.info(f"成功为用户配置文件创建了嵌入向量")
            return embedding
        except Exception as e:
            logger.error(f"创建配置文件嵌入向量时出错: {e}")
            return None

    def _compute_cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
        """
        Cosine similarity of two vectors, in [-1, 1].

        Returns 0.0 for None or empty inputs, mismatched lengths (logged), or
        zero-norm vectors.
        """
        if vec1 is None or vec2 is None or len(vec1) == 0 or len(vec2) == 0:
            return 0.0
        if len(vec1) != len(vec2):
            logger.warning(f"向量维度不匹配: {len(vec1)} vs {len(vec2)}")
            return 0.0
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        if norm1 == 0 or norm2 == 0:
            return 0.0
        return np.dot(vec1, vec2) / (norm1 * norm2)

    def get_profile(self) -> Dict:
        """Return the user profile held by this selector (may be None)."""
        return self.user_profile

    def _check_if_career_needed(self, profile: Dict, profile_summary: str) -> bool:
        """
        Ask GPT whether the "Career and Work Identity" category applies.

        Args:
            profile: User profile dict (currently unused; the summary is sent).
            profile_summary: Text summary of the profile.

        Returns:
            True when career attributes are needed. Defaults to True when the
            response lacks the key or the call fails.
        """
        prompt = f"""
# Career Attribute Necessity Assessment
## User Profile Summary:
{profile_summary}
## Task Description:
You are an AI assistant tasked with determining whether the "Career and Work Identity" attribute category is relevant and necessary for this specific individual based on their profile.
Consider the following factors:
1. The person's age and life stage
2. Current employment or educational status
3. How central career and work identity appears to be to this person's life
4. Whether career information would be essential to create a complete picture of this person
## Output Format:
Provide a JSON response with these keys:
1. "is_career_needed": Boolean (true/false) indicating whether Career and Work Identity attributes are needed
2. "reasoning": String explaining your rationale
Your response must be valid JSON that can be parsed by Python's json.loads() function.
"""
        try:
            messages = [
                {"role": "system", "content": "You are an AI assistant that analyzes user profiles and determines whether career attributes are necessary. You always respond with valid JSON."},
                {"role": "user", "content": prompt}
            ]
            content = get_completion(messages, model=self.model, temperature=0.3)
            # parse_json_response (config.py) falls back to the given default.
            result = parse_json_response(content, {"is_career_needed": True, "reasoning": "默认需要职业属性"})
            if "is_career_needed" not in result:
                logger.warning("Missing 'is_career_needed' key in GPT response, defaulting to True")
                result["is_career_needed"] = True
            needed = result["is_career_needed"]
            # Models sometimes answer with strings such as "false"; plain
            # truthiness would treat those as True.
            if isinstance(needed, str):
                needed = needed.strip().lower() not in ("false", "no", "0", "")
            needed = bool(needed)
            logger.info(f"Career attributes needed: {needed}, Reason: {result.get('reasoning', 'No reasoning provided')}")
            return needed
        except Exception as e:
            logger.error(f"Error calling GPT API for career assessment: {e}")
            # On error, keep career attributes.
            return True

    def _extract_profile_summary(self, profile: Dict) -> str:
        """
        Build a newline-separated English text summary of a profile.

        Only keys present in the profile are included (age info, gender,
        location, career status, values, life attitude, interests, personal
        story and an optional top-level "summary").
        """
        summary_parts = []
        if "age_info" in profile:
            age_info = profile["age_info"]
            if "age" in age_info:
                summary_parts.append(f"Age: {age_info['age']}")
            if "age_group" in age_info:
                summary_parts.append(f"Age Group: {age_info['age_group']}")
        if "gender" in profile:
            gender = profile["gender"]
            # Translate Chinese gender labels to English.
            if gender == "男":
                gender = "Male"
            elif gender == "女":
                gender = "Female"
            summary_parts.append(f"Gender: {gender}")
        if "location" in profile:
            location = profile["location"]
            if isinstance(location, dict):
                location_str = f"{location.get('city', '')}, {location.get('country', '')}"
            else:
                location_str = str(location)
            summary_parts.append(f"Location: {location_str}")
        if "career_info" in profile and "status" in profile["career_info"]:
            summary_parts.append(f"Career Status: {profile['career_info']['status']}")
        if "personal_values" in profile and "values_orientation" in profile["personal_values"]:
            summary_parts.append(f"Values: {profile['personal_values']['values_orientation']}")
        if "life_attitude" in profile:
            attitude = profile["life_attitude"]
            if "attitude" in attitude:
                summary_parts.append(f"Life Attitude: {attitude['attitude']}")
            if "attitude_details" in attitude:
                summary_parts.append(f"Attitude Details: {attitude['attitude_details']}")
            if "coping_mechanism" in attitude:
                summary_parts.append(f"Coping Mechanism: {attitude['coping_mechanism']}")
        if "interests" in profile and "interests" in profile["interests"]:
            interests = profile["interests"]["interests"]
            if interests and isinstance(interests, list):
                # str() so non-string entries do not make join() raise.
                summary_parts.append(f"Interests: {', '.join(str(i) for i in interests)}")
        if "personal_story" in profile and "personal_story" in profile["personal_story"]:
            life_story = profile["personal_story"]["personal_story"]
            if life_story:
                summary_parts.append(f"Life Story: {life_story}")
        if "summary" in profile:
            summary_parts.append(f"Profile Summary: {profile['summary']}")
        return "\n".join(summary_parts)

    def _flatten_attributes(self, attributes: Dict, prefix: str = "") -> List[str]:
        """
        Flatten a nested attribute dict into dotted leaf paths.

        Leaves are empty dicts or non-dict values.

        Args:
            attributes: Attribute dict or sub-dict.
            prefix: Path prefix prepended to every key.
        """
        result = []

        def _flatten(attr_dict, curr_prefix):
            for k, v in attr_dict.items():
                path = f"{curr_prefix}.{k}" if curr_prefix else k
                if isinstance(v, dict) and v:
                    _flatten(v, path)
                else:
                    result.append(path)

        _flatten(attributes, prefix)
        return result

    def _get_attribute_categories(self) -> List[str]:
        """Return the sorted unique top-level categories of all attribute paths."""
        return sorted({path.split('.')[0] for path in self.attributes["paths"]})

    def _format_attributes_tree(self, attributes_dict: Dict, prefix: str = "", depth: int = 0) -> List[str]:
        """
        Render a nested attribute dict as text-tree lines, keys sorted.

        Args:
            attributes_dict: Attribute dict.
            prefix: Current path prefix.
            depth: Current depth in the tree.

        Returns:
            One formatted line per key, children following their parent.
        """
        lines = []
        for key, value in sorted(attributes_dict.items()):
            current_path = f"{prefix}.{key}" if prefix else key
            indent = " " * depth
            if depth > 0:
                prefix_char = "│ " * (depth - 1) + "- "
            else:
                prefix_char = "- "
            lines.append(f"{indent}{prefix_char}{key}")
            if isinstance(value, dict) and value:
                lines.extend(self._format_attributes_tree(value, current_path, depth + 1))
        return lines

    def analyze_profile_for_attributes(self, profile: Dict) -> Dict[str, List[str]]:
        """
        Decide which top-level attribute categories to generate for a profile.

        Every category is recommended except "Career and Work Identity", which
        is excluded when GPT judges it unnecessary.

        Returns:
            Dict with 'recommended', 'not_recommended' and 'reasoning' keys.
        """
        profile_summary = self._extract_profile_summary(profile)
        categories = self._get_attribute_categories()
        career_needed = self._check_if_career_needed(profile, profile_summary)
        recommended_categories = []
        not_recommended_categories = []
        for category in categories:
            if category == "Career and Work Identity" and not career_needed:
                not_recommended_categories.append(category)
                logger.info(f"根据分析,不需要职业和工作身份属性")
            else:
                recommended_categories.append(category)
        return {
            "recommended": recommended_categories,
            "not_recommended": not_recommended_categories,
            "reasoning": f"根据用户背景分析,{'需要' if career_needed else '不需要'}职业和工作身份属性。保留所有其他一级属性,从每个属性中选择最符合用户背景的特征。"
        }

    def process_profile(self) -> Dict:
        """
        Summarise the profile and compute category recommendations.

        A profile is generated via generate_user_profile() when none was given.

        Returns:
            Dict with 'profile_summary' and 'attribute_recommendations'.
        """
        if not self.user_profile:
            self.user_profile = generate_user_profile()
        profile_summary = self._extract_profile_summary(self.user_profile)
        attribute_recommendations = self.analyze_profile_for_attributes(self.user_profile)
        return {
            "profile_summary": profile_summary,
            "attribute_recommendations": attribute_recommendations
        }

    def _get_nested_attributes(self, category: str) -> Dict:
        """
        Rebuild the nested attribute dict of one category from the path list.

        Args:
            category: Top-level category name (excluded from the result keys).
        """
        result = {}
        for path in [p for p in self.attributes["paths"] if p.startswith(f"{category}.")]:
            parts = path.split('.')[1:]
            current = result
            for i, part in enumerate(parts):
                if i == len(parts) - 1:
                    current[part] = {}
                else:
                    current.setdefault(part, {})
                    current = current[part]
        return result

    def select_top_attributes(self, path_list, target_count=200):
        """
        Pick up to ``target_count`` representative paths from ``path_list``.

        Slots are allocated per top-level category in proportion to its share
        of the paths (at least one per category). Within a category, shallower
        paths and paths containing generic terms ("general", "core", ...) win.

        Returns:
            ``path_list`` unchanged if it already has <= target_count entries.
            Otherwise exactly ``target_count`` paths, unless the
            one-per-category minimum forces more.
        """
        if not path_list:
            return []
        if len(path_list) <= target_count:
            return path_list

        categories: Dict[str, List[str]] = {}
        for path in path_list:
            categories.setdefault(path.split('.')[0], []).append(path)

        total_paths = len(path_list)
        category_counts = {
            category: max(1, int(len(paths) / total_paths * target_count))
            for category, paths in categories.items()
        }

        total_selected = sum(category_counts.values())
        if total_selected < target_count:
            # Flooring loses < 1 slot per category, so one extra slot for each
            # of the largest categories is enough to reach the target.
            remaining = target_count - total_selected
            by_size = sorted(categories, key=lambda c: len(categories[c]), reverse=True)
            for category in by_size[:remaining]:
                category_counts[category] += 1
        elif total_selected > target_count:
            # Repeatedly take a slot from the currently largest allocation that
            # is above the minimum of 1. A single pass could stop early and
            # leave the result above target_count.
            excess = total_selected - target_count
            while excess > 0:
                reducible = [c for c in category_counts if category_counts[c] > 1]
                if not reducible:
                    break  # every category is already at its minimum
                largest = max(reducible, key=lambda c: category_counts[c])
                category_counts[largest] -= 1
                excess -= 1

        common_terms = ['general', 'common', 'basic', 'core', 'essential', 'fundamental', 'key']
        selected_paths = []
        for category, count in category_counts.items():
            scored_paths = []
            for path in categories[category]:
                depth = len(path.split('.'))
                has_common_term = any(term in path.lower() for term in common_terms)
                scored_paths.append((path, (10 - depth) + (5 if has_common_term else 0)))
            scored_paths = sorted(scored_paths, key=lambda x: x[1], reverse=True)
            selected_paths.extend(p[0] for p in scored_paths[:count])
        return selected_paths

    def _find_interesting_neighbors(self, profile_embedding: np.ndarray, category_paths: List[str],
                                    target_count: int = 300) -> List[str]:
        """
        "Interesting neighbours" vector search within one category.

        Paths are ranked by cosine similarity to the profile and split into
        thirds. About 50% of the target is sampled from the most similar third,
        30% from the middle third and 20% from the least similar third. Any
        remaining slots are filled randomly from unused paths.

        Args:
            profile_embedding: Embedding of the user profile.
            category_paths: Candidate attribute paths.
            target_count: Number of paths to select.

        Returns:
            The selected paths. Returns all embedded paths when there are no
            more than target_count of them, and [] when no embeddings are
            available.
        """
        if not self.embeddings_data or not category_paths:
            logger.warning("没有可用的向量数据库或类别路径为空,返回空列表")
            return []
        valid_paths = [path for path in category_paths if path in self.path_to_embedding]
        if not valid_paths:
            logger.warning("没有有效的属性路径匹配向量数据库,返回空列表")
            return []

        path_similarities = [
            (path, self._compute_cosine_similarity(profile_embedding, self.path_to_embedding[path]))
            for path in valid_paths
        ]
        path_similarities.sort(key=lambda x: x[1], reverse=True)
        total_paths = len(path_similarities)
        if total_paths <= target_count:
            return [p[0] for p in path_similarities]

        selected_paths = []
        used_indices = set()

        def take_band(indices, count):
            """Randomly take up to ``count`` unused indices from ``indices``."""
            if not indices:
                return
            for i in random.sample(indices, min(count, len(indices))):
                if i not in used_indices:
                    selected_paths.append(path_similarities[i][0])
                    used_indices.add(i)

        total_ratio = 5 + 3 + 2
        # Near band: the most similar third.
        near_count = min(int(target_count * 5 / total_ratio), total_paths // 3)
        take_band(list(range(near_count)), near_count)
        # Middle band: the middle third.
        mid_start = total_paths // 3
        mid_end = 2 * total_paths // 3
        mid_count = min(int(target_count * 3 / total_ratio), (mid_end - mid_start))
        take_band(list(range(mid_start, mid_end)), mid_count)
        # Far band: the least similar third.
        far_start = 2 * total_paths // 3
        far_count = min(int(target_count * 2 / total_ratio), (total_paths - far_start))
        take_band(list(range(far_start, total_paths)), far_count)

        remaining_count = target_count - len(selected_paths)
        if remaining_count > 0:
            remaining_indices = [i for i in range(total_paths) if i not in used_indices]
            if remaining_indices:
                for i in random.sample(remaining_indices, min(remaining_count, len(remaining_indices))):
                    selected_paths.append(path_similarities[i][0])
        # NOTE: the logged counts are the planned band sizes, not the final totals.
        logger.info(f"使用向量搜索选择了 {len(selected_paths)} 条属性(近邻: {near_count}, 中距离: {mid_count}, 远距离: {far_count})")
        return selected_paths

    def get_top_attributes(self, result: Dict, target_count: int = 200) -> List[str]:
        """
        Select attribute paths for the recommended categories.

        Each recommended category gets a share of ``target_count``
        proportional to its size (minimum 3). Paths are then chosen with
        ``_find_interesting_neighbors``. If the total exceeds target_count,
        it is randomly trimmed to target_count.

        Args:
            result: Output of ``analyze_profile_for_attributes``.
            target_count: Desired number of paths.

        Returns:
            The selected paths. Returns [] when no embeddings or no profile
            are available, when embedding the profile fails, or on any error.
        """
        try:
            all_recommended = set(result.get("recommended", []))
            all_paths = []
            category_paths = {}
            if "paths" in self.attributes:
                for path in self.attributes["paths"]:
                    category = path.split('.')[0] if '.' in path else path
                    if category in all_recommended:
                        all_paths.append(path)
                        category_paths.setdefault(category, []).append(path)
            else:
                # Legacy nested format (normally unreachable: _validate_data
                # converts to path format). Flatten the category dict directly;
                # _get_nested_attributes would itself need the "paths" key.
                for category in all_recommended:
                    sub_tree = self.attributes.get(category, {})
                    if isinstance(sub_tree, dict) and sub_tree:
                        paths = self._flatten_attributes(sub_tree, category)
                    else:
                        paths = [category]
                    all_paths.extend(paths)
                    category_paths[category] = paths

            if self.embeddings_data and self.user_profile:
                profile_embedding = self._create_profile_embedding(self.user_profile)
                if profile_embedding is not None:
                    final_paths = []
                    for category, paths in category_paths.items():
                        category_ratio = len(paths) / len(all_paths)
                        category_target = max(3, int(target_count * category_ratio))
                        final_paths.extend(
                            self._find_interesting_neighbors(profile_embedding, paths, category_target)
                        )
                    # The per-category minimum of 3 can overshoot; trim randomly.
                    if len(final_paths) > target_count:
                        final_paths = random.sample(final_paths, target_count)
                    logger.info(f"使用向量搜索从 {len(all_paths)} 条属性中选择了 {len(final_paths)} 条属性")
                    return final_paths
            logger.warning(f"没有可用的向量数据库或向量搜索失败,返回空列表")
            return []
        except Exception as e:
            logger.error(f"Error generating attribute list: {e}")
            logger.error(f"属性选择过程出错,返回空列表")
            return []
# Note: _random_select_from_top_categories was removed; selection now happens
# in get_top_attributes.
def convert_web_profile_to_selector_format(web_profile: Dict) -> Dict:
"""
将web_api_bridge.py保存的配置文件格式转换为AttributeSelector期望的格式
核心原则:用户输入了什么就用什么,没输入的才生成
web_profile格式:
{
"age": int,
"gender": str,
"Occupations": [str],
"location": {"city": str, "country": str},
"personal_values": {"values_orientation": str},
"life_attitude": str or dict,
"personal_story": {"personal_story": str},
"interests": {"interests": [str]}
}
selector格式:
{
"age_info": {"age": int, "age_group": str},
"gender": str,
"location": {"city": str, "country": str},
"career_info": {"status": str},
"personal_values": {"values_orientation": str},
"life_attitude": {"attitude": str, "attitude_details": str, "coping_mechanism": str},
"personal_story": {"personal_story": str},
"interests": {"interests": [str]}
}
"""
# 提取年龄信息 - 如果用户提供了就用,否则生成
age = web_profile.get("age")
if age:
# 根据年龄确定年龄组
if age <= 6:
age_group = "toddler"
elif age <= 12:
age_group = "child"
elif age <= 19:
age_group = "adolescent"
elif age <= 29:
age_group = "young_adult"
elif age <= 45:
age_group = "adult"
elif age <= 65:
age_group = "middle_aged"
else:
age_group = "senior"
age_info = {"age": age, "age_group": age_group}
print(f"✓ Using user-provided age: {age}")
else:
age_info = generate_age_info()
print(f"✗ Age not provided, generated: {age_info['age']}")
# 提取性别 - 如果用户提供了就用,否则生成
gender = web_profile.get("gender")
if gender:
print(f"✓ Using user-provided gender: {gender}")
else:
gender = generate_gender()
print(f"✗ Gender not provided, generated: {gender}")
# 提取位置 - 如果用户提供了就用,否则生成
location = web_profile.get("location")
if location and location.get("city") and location.get("country"):
print(f"✓ Using user-provided location: {location.get('city')}, {location.get('country')}")
else:
location = generate_location()
print(f"✗ Location not provided, generated: {location.get('city')}, {location.get('country')}")
# 提取职业信息 - 如果用户提供了就用,否则生成
occupations = web_profile.get("Occupations", [])
if occupations and len(occupations) > 0 and occupations[0]:
career_info = {"status": occupations[0]}
print(f"✓ Using user-provided occupation: {occupations[0]}")
else:
career_info = generate_career_info(age_info["age"])
print(f"✗ Occupation not provided, generated: {career_info['status']}")
# 提取个人价值观 - 如果用户提供了就用,否则生成
personal_values = web_profile.get("personal_values", {})
values_orientation = personal_values.get("values_orientation", "")
if values_orientation and values_orientation.strip():
# 用户提供了价值观,直接使用
personal_values = {"values_orientation": values_orientation.strip()}
print(f"✓ Using user-provided personal values: {values_orientation[:50]}...")
else:
# 用户没提供,基于已有信息生成
personal_values = generate_personal_values(
age=age_info["age"],
gender=gender,
occupation=career_info["status"],
location=location
)
print(f"✗ Personal values not provided, generated based on user inputs")
# 提取生活态度 - 如果用户提供了就用,否则生成
life_attitude_data = web_profile.get("life_attitude")
if life_attitude_data:
if isinstance(life_attitude_data, str) and life_attitude_data.strip():
# 用户提供了字符串格式的生活态度
life_attitude = {
"attitude": life_attitude_data.strip(),
"attitude_details": "",
"coping_mechanism": ""
}
print(f"✓ Using user-provided life attitude: {life_attitude_data[:50]}...")
elif isinstance(life_attitude_data, dict) and life_attitude_data.get("attitude"):
# 用户提供了字典格式的生活态度
life_attitude = life_attitude_data
print(f"✓ Using user-provided life attitude (dict format)")
else:
# 数据格式不对或为空,生成新的
life_attitude = generate_life_attitude(
age=age_info["age"],
gender=gender,
occupation=career_info["status"],
location=location,
values_orientation=personal_values.get("values_orientation", "")
)
print(f"✗ Life attitude not provided, generated based on user inputs")
else:
# 用户没提供,基于已有信息生成
life_attitude = generate_life_attitude(
age=age_info["age"],
gender=gender,
occupation=career_info["status"],
location=location,
values_orientation=personal_values.get("values_orientation", "")
)
print(f"✗ Life attitude not provided, generated based on user inputs")
# 提取个人故事 - 如果用户提供了就用,否则生成
personal_story_data = web_profile.get("personal_story", {})
if isinstance(personal_story_data, dict) and personal_story_data.get("personal_story"):
story_text = personal_story_data.get("personal_story", "")
if story_text and story_text.strip():
# 用户提供了故事,直接使用
personal_story = {"personal_story": story_text.strip()}
print(f"✓ Using user-provided life story: {story_text[:50]}...")
else:
# 故事为空,生成新的
personal_story = generate_personal_story(
age=age_info["age"],
gender=gender,
occupation=career_info["status"],
location=location,
values_orientation=personal_values.get("values_orientation", ""),
life_attitude=life_attitude
)
print(f"✗ Life story not provided, generated based on user inputs")
elif isinstance(personal_story_data, str) and personal_story_data.strip():
# 用户提供了字符串格式的故事
personal_story = {"personal_story": personal_story_data.strip()}
print(f"✓ Using user-provided life story (string format)")
else:
# 用户没提供,基于已有信息生成
personal_story = generate_personal_story(
age=age_info["age"],
gender=gender,
occupation=career_info["status"],
location=location,
values_orientation=personal_values.get("values_orientation", ""),
life_attitude=life_attitude
)
print(f"✗ Life story not provided, generated based on user inputs")
# 提取兴趣爱好 - 如果用户提供了就用,否则生成
interests = web_profile.get("interests", {})
interests_list = interests.get("interests", [])
if interests_list and len(interests_list) > 0 and any(i.strip() for i in interests_list):
# 用户提供了兴趣爱好,直接使用
interests = {"interests": [i.strip() for i in interests_list if i.strip()]}
print(f"✓ Using user-provided interests: {', '.join(interests['interests'])}")
else:
# 用户没提供,基于故事生成
interests = generate_interests_and_hobbies(personal_story)
print(f"✗ Interests not provided, generated based on life story")
# 返回转换后的格式
return {
"age_info": age_info,
"gender": gender,
"location": location,
"career_info": career_info,
"personal_values": personal_values,
"life_attitude": life_attitude,
"personal_story": personal_story,
"interests": interests
}
def generate_user_profile() -> Dict:
    """
    Build a random base user profile with the based_data generators.

    Generation is chained: age, gender, location and career come first; they
    feed the personal values; values feed the life attitude; both feed the
    personal story; and the story feeds the interests.

    Returns:
        Dict with keys age_info, gender, location, career_info,
        personal_values, life_attitude, personal_story and interests.
    """
    age_info = generate_age_info()
    gender = generate_gender()
    location = generate_location()
    career_info = generate_career_info(age_info["age"])

    # Demographic keyword arguments passed to each downstream generator.
    demographics = {
        "age": age_info["age"],
        "gender": gender,
        "occupation": career_info["status"],
        "location": location,
    }

    personal_values = generate_personal_values(**demographics)
    orientation = personal_values.get("values_orientation", "")
    life_attitude = generate_life_attitude(**demographics, values_orientation=orientation)
    personal_story = generate_personal_story(
        **demographics,
        values_orientation=orientation,
        life_attitude=life_attitude,
    )
    interests = generate_interests_and_hobbies(personal_story)

    return {
        "age_info": age_info,
        "gender": gender,
        "location": location,
        "career_info": career_info,
        "personal_values": personal_values,
        "life_attitude": life_attitude,
        "personal_story": personal_story,
        "interests": interests,
    }
def get_selected_attributes(user_profile=None, attribute_count=200):
    """
    Run the full attribute-selection pipeline for one profile.

    Args:
        user_profile: Profile dict. A random one is generated when None.
        attribute_count: Target number of attribute paths.

    Returns:
        List of dotted attribute paths, or [] if any step raises (logged).
        The result is also stored in ATTRIBUTE_SELECTION_CACHE, but that value
        is never read back, so every call re-runs the selection.
    """
    global ATTRIBUTE_SELECTION_CACHE
    try:
        profile = generate_user_profile() if user_profile is None else user_profile
        selector = AttributeSelector(user_profile=profile)
        processed = selector.process_profile()
        recommendations = processed.get("attribute_recommendations", {})
        chosen_paths = selector.get_top_attributes(recommendations, target_count=attribute_count)
        ATTRIBUTE_SELECTION_CACHE = chosen_paths
        return chosen_paths
    except Exception as e:
        logger.error(f"Error getting selected attributes: {e}")
        return []
def build_nested_dict(paths: List[str]) -> Dict:
    """
    Turn dotted attribute paths into a nested dict with empty-dict leaves.

    >>> build_nested_dict(["a.b", "a.c"])
    {'a': {'b': {}, 'c': {}}}
    """
    tree: Dict = {}
    for dotted in paths:
        node = tree
        for segment in dotted.split('.'):
            node = node.setdefault(segment, {})
    return tree
def save_results(user_profile: Dict, selected_paths: List[str], output_dir: str = '/home/zhou/persona/generate_user_profile/output') -> None:
    """
    Write the user profile and the selected attribute paths to JSON files.

    Creates ``output_dir`` (with parents) if needed. Then writes
    ``user_profile.json``, and ``selected_paths.json`` with the paths
    converted to a nested dict by build_nested_dict().

    Args:
        user_profile: Profile dict to save.
        selected_paths: Dotted attribute paths.
        output_dir: Target directory.

    Raises:
        Exception: any filesystem or serialisation error, after logging it.
    """
    try:
        target_dir = Path(output_dir)
        target_dir.mkdir(parents=True, exist_ok=True)

        profile_file = target_dir / "user_profile.json"
        with open(profile_file, 'w', encoding='utf-8') as handle:
            json.dump(user_profile, handle, ensure_ascii=False, indent=2)
        logger.info(f"用户配置文件已保存到 {profile_file}")

        paths_file = target_dir / "selected_paths.json"
        with open(paths_file, 'w', encoding='utf-8') as handle:
            json.dump(build_nested_dict(selected_paths), handle, ensure_ascii=False, indent=2)
        logger.info(f"属性路径已保存到 {paths_file}")
    except Exception as e:
        logger.error(f"保存结果时出错: {e}")
        raise
# Example: generate a base profile, select attributes for it and save both.
# Guarded by __main__ so that importing this module does not trigger OpenAI
# calls or write files to the hard-coded default output directory.
if __name__ == "__main__":
    user_profile = generate_user_profile()
    selected_paths = get_selected_attributes(user_profile)
    save_results(user_profile, selected_paths)
# This file is intended to be imported by other modules.