Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| # 属性选择器脚本 | |
| # 此脚本用于分析用户配置文件并选择最适合的属性 | |
| import json | |
| import os | |
| import sys | |
| import logging | |
| from typing import Dict, List, Any, Optional, Tuple | |
| import argparse | |
| from pathlib import Path | |
| import random | |
| import numpy as np | |
| import pickle | |
| from tqdm import tqdm | |
| import time | |
| ATTRIBUTE_SELECTION_CACHE = None | |
| # 导入项目配置 | |
| from config import client, GPT_MODEL, parse_json_response | |
| # 定义get_completion函数 | |
| def get_completion(messages, model=GPT_MODEL, temperature=0.7): | |
| """使用OpenAI API生成文本完成""" | |
| try: | |
| response = client.chat.completions.create( | |
| model=model, | |
| messages=messages, | |
| temperature=temperature | |
| ) | |
| return response.choices[0].message.content | |
| except Exception as e: | |
| logger.error(f"Error calling OpenAI API: {e}") | |
| return None | |
| # 导入based_data模块中的函数 | |
| from based_data import ( | |
| generate_age_info, | |
| generate_gender, | |
| generate_career_info, | |
| generate_location, | |
| generate_personal_values, | |
| generate_life_attitude, | |
| generate_personal_story, | |
| generate_interests_and_hobbies | |
| ) | |
| # 配置日志 | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' | |
| ) | |
| logger = logging.getLogger(__name__) | |
| # 属性数据集路径 | |
| ATTRIBUTES_PATH = "/home/zhou/deeppersona/generate_user_profile_test/data/large_attributes.json" # 属性数据集路径 | |
| # 向量数据库路径 | |
| EMBEDDINGS_PATH = "/home/zhou/deeppersona/generate_user_profile_test/data/attribute_embeddings.pkl" # 属性嵌入向量路径 | |
| # 默认模型来自配置 | |
| DEFAULT_MODEL = GPT_MODEL | |
| # 向量搜索参数 | |
| NEAR_NEIGHBOR_COUNT = 7 # 近邻数量 | |
| MID_NEIGHBOR_COUNT = 2 # 中距离邻居数量 | |
| FAR_NEIGHBOR_COUNT = 1 # 远距离邻居数量 | |
| DIVERSITY_THRESHOLD = 0.7 # 多样性阈值(余弦相似度) | |
| class AttributeSelector: | |
| """ | |
| 属性选择器类 | |
| 用于分析用户配置文件并使用GPT-4o选择适当的属性 | |
| """ | |
| def __init__(self, model: str = DEFAULT_MODEL, user_profile: Dict = None): | |
| """ | |
| 初始化属性选择器 | |
| 参数: | |
| model: 要使用的GPT模型 | |
| user_profile: 用户配置文件数据(可选) | |
| """ | |
| self.model = model | |
| # OpenAI客户端已在config.py中设置 | |
| # 加载属性数据 | |
| self.attributes = self._load_json(ATTRIBUTES_PATH) # 加载属性数据 | |
| # 设置用户配置文件 | |
| self.user_profile = user_profile | |
| # 验证数据 | |
| self._validate_data() | |
| # 加载向量数据库 | |
| self.embeddings_data = self._load_embeddings() | |
| # 初始化属性路径和向量映射 | |
| self.path_to_embedding = {} | |
| self.paths = [] | |
| self.embeddings = [] | |
| if self.embeddings_data: | |
| self.paths = self.embeddings_data.get('paths', []) | |
| self.embeddings = self.embeddings_data.get('embeddings', []) | |
| # 创建路径到向量的映射 | |
| for i, path in enumerate(self.paths): | |
| if i < len(self.embeddings): | |
| self.path_to_embedding[path] = self.embeddings[i] | |
| logger.info(f"已加载 {len(self.paths)} 条属性路径和对应的向量嵌入") | |
| logger.info(f"已加载属性,包含 {len(self.attributes.keys())} 个顶级类别") | |
| def _load_json(self, file_path: str) -> Dict: | |
| """从文件加载JSON数据""" | |
| try: | |
| return json.loads(Path(file_path).read_text(encoding='utf-8')) | |
| except Exception as e: | |
| logger.error(f"从 {file_path} 加载JSON时出错: {e}") | |
| raise | |
| def _load_embeddings(self) -> Dict: | |
| """加载属性嵌入向量数据库""" | |
| try: | |
| if not os.path.exists(EMBEDDINGS_PATH): | |
| logger.warning(f"嵌入向量文件 {EMBEDDINGS_PATH} 不存在") | |
| return None | |
| with open(EMBEDDINGS_PATH, 'rb') as f: | |
| embeddings_data = pickle.load(f) | |
| # 检查数据结构并标准化键名 | |
| paths_key = 'attribute_paths' if 'attribute_paths' in embeddings_data else 'paths' | |
| embeddings_key = 'embeddings' | |
| # 获取路径和嵌入向量 | |
| paths = embeddings_data.get(paths_key, []) | |
| embeddings = embeddings_data.get(embeddings_key, []) | |
| # 如果数据无效,返回空值 | |
| if not isinstance(embeddings_data, dict) or not paths or not isinstance(embeddings, np.ndarray): | |
| logger.warning("嵌入向量数据格式无效") | |
| return None | |
| # 标准化返回的数据字典 | |
| standardized_data = { | |
| 'paths': paths, | |
| 'embeddings': embeddings | |
| } | |
| logger.info(f"从 {EMBEDDINGS_PATH} 加载了 {len(paths)} 条属性嵌入向量") | |
| return standardized_data | |
| except Exception as e: | |
| logger.error(f"加载嵌入向量时出错: {e}") | |
| return None | |
| def _validate_data(self): | |
| """验证加载的数据并转换格式(如需要)""" | |
| # 检查属性格式 | |
| if isinstance(self.attributes, dict): | |
| # 如果是嵌套字典格式(如large_attributes.json),转换为路径格式 | |
| if "paths" not in self.attributes: | |
| logger.info("正在将属性从嵌套字典格式转换为路径格式") | |
| self.attributes = {"paths": self._flatten_attributes(self.attributes)} | |
| else: | |
| raise ValueError("无效的属性格式:不是字典") | |
| def _create_profile_embedding(self, profile: Dict) -> np.ndarray: | |
| """ | |
| 为用户配置文件创建嵌入向量 | |
| 参数: | |
| profile: 用户配置文件字典 | |
| 返回: | |
| 配置文件的嵌入向量 | |
| """ | |
| try: | |
| # 提取配置文件摘要 | |
| profile_summary = self._extract_profile_summary(profile) | |
| # 使用OpenAI API生成嵌入向量 | |
| response = client.embeddings.create( | |
| model="text-embedding-ada-002", | |
| input=profile_summary | |
| ) | |
| # 提取嵌入向量 | |
| embedding = np.array(response.data[0].embedding) | |
| logger.info(f"成功为用户配置文件创建了嵌入向量") | |
| return embedding | |
| except Exception as e: | |
| logger.error(f"创建配置文件嵌入向量时出错: {e}") | |
| return None | |
| def _compute_cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float: | |
| """ | |
| 计算两个向量之间的余弦相似度 | |
| 参数: | |
| vec1: 第一个向量 | |
| vec2: 第二个向量 | |
| 返回: | |
| 余弦相似度值,范围为[-1, 1] | |
| """ | |
| # 处理空向量或None值 | |
| if vec1 is None or vec2 is None or len(vec1) == 0 or len(vec2) == 0: | |
| return 0.0 | |
| # 确保向量维度匹配 | |
| if len(vec1) != len(vec2): | |
| logger.warning(f"向量维度不匹配: {len(vec1)} vs {len(vec2)}") | |
| return 0.0 | |
| # 计算向量范数 | |
| norm1 = np.linalg.norm(vec1) | |
| norm2 = np.linalg.norm(vec2) | |
| # 处理零向量 | |
| if norm1 == 0 or norm2 == 0: | |
| return 0.0 | |
| # 计算余弦相似度 | |
| return np.dot(vec1, vec2) / (norm1 * norm2) | |
| def get_profile(self) -> Dict: | |
| """获取用户配置文件""" | |
| return self.user_profile | |
| def _check_if_career_needed(self, profile: Dict, profile_summary: str) -> bool: | |
| """ | |
| 使用GPT判断是否需要"Career and Work Identity"属性 | |
| 参数: | |
| profile: 用户配置文件字典 | |
| profile_summary: 用户配置文件摘要 | |
| 返回: | |
| 布尔值,表示是否需要职业相关属性 | |
| """ | |
| # 创建专门用于判断职业属性必要性的提示词 | |
| prompt = f""" | |
| # Career Attribute Necessity Assessment | |
| ## User Profile Summary: | |
| {profile_summary} | |
| ## Task Description: | |
| You are an AI assistant tasked with determining whether the "Career and Work Identity" attribute category is relevant and necessary for this specific individual based on their profile. | |
| Consider the following factors: | |
| 1. The person's age and life stage | |
| 2. Current employment or educational status | |
| 3. How central career and work identity appears to be to this person's life | |
| 4. Whether career information would be essential to create a complete picture of this person | |
| ## Output Format: | |
| Provide a JSON response with these keys: | |
| 1. "is_career_needed": Boolean (true/false) indicating whether Career and Work Identity attributes are needed | |
| 2. "reasoning": String explaining your rationale | |
| Your response must be valid JSON that can be parsed by Python's json.loads() function. | |
| """ | |
| try: | |
| # 调用GPT API,使用config.py中的get_completion函数 | |
| messages = [ | |
| {"role": "system", "content": "You are an AI assistant that analyzes user profiles and determines whether career attributes are necessary. You always respond with valid JSON."}, | |
| {"role": "user", "content": prompt} | |
| ] | |
| content = get_completion(messages, model=self.model, temperature=0.3) | |
| # Use the project's JSON parsing function | |
| result = parse_json_response(content, {"is_career_needed": True, "reasoning": "默认需要职业属性"}) | |
| # 确保结果包含必要的键 | |
| if "is_career_needed" not in result: | |
| logger.warning("Missing 'is_career_needed' key in GPT response, defaulting to True") | |
| result["is_career_needed"] = True | |
| logger.info(f"Career attributes needed: {result['is_career_needed']}, Reason: {result.get('reasoning', 'No reasoning provided')}") | |
| return result["is_career_needed"] | |
| except Exception as e: | |
| logger.error(f"Error calling GPT API for career assessment: {e}") | |
| # 出错时默认需要职业属性 | |
| return True | |
| def _extract_profile_summary(self, profile: Dict) -> str: | |
| """ | |
| 提取配置文件摘要,用于GPT分析 | |
| 参数: | |
| profile: 用户配置文件字典 | |
| 返回: | |
| 配置文件摘要字符串 | |
| """ | |
| summary_parts = [] | |
| # Add age information | |
| if "age_info" in profile: | |
| age_info = profile["age_info"] | |
| if "age" in age_info: | |
| summary_parts.append(f"Age: {age_info['age']}") | |
| if "age_group" in age_info: | |
| summary_parts.append(f"Age Group: {age_info['age_group']}") | |
| # Add gender information | |
| if "gender" in profile: | |
| gender = profile["gender"] | |
| # Convert Chinese characters to English if needed | |
| if gender == "男": | |
| gender = "Male" | |
| elif gender == "女": | |
| gender = "Female" | |
| summary_parts.append(f"Gender: {gender}") | |
| # Add location information | |
| if "location" in profile: | |
| location = profile["location"] | |
| location_str = f"{location.get('city', '')}, {location.get('country', '')}" | |
| summary_parts.append(f"Location: {location_str}") | |
| # Add career information | |
| if "career_info" in profile and "status" in profile["career_info"]: | |
| summary_parts.append(f"Career Status: {profile['career_info']['status']}") | |
| # Add personal values | |
| if "personal_values" in profile and "values_orientation" in profile["personal_values"]: | |
| summary_parts.append(f"Values: {profile['personal_values']['values_orientation']}") | |
| # Add life attitude | |
| if "life_attitude" in profile: | |
| attitude = profile["life_attitude"] | |
| if "attitude" in attitude: | |
| summary_parts.append(f"Life Attitude: {attitude['attitude']}") | |
| if "attitude_details" in attitude: | |
| summary_parts.append(f"Attitude Details: {attitude['attitude_details']}") | |
| if "coping_mechanism" in attitude: | |
| summary_parts.append(f"Coping Mechanism: {attitude['coping_mechanism']}") | |
| # Add interests | |
| if "interests" in profile and "interests" in profile["interests"]: | |
| interests = profile["interests"]["interests"] | |
| if interests and isinstance(interests, list): | |
| summary_parts.append(f"Interests: {', '.join(interests)}") | |
| # Add personal story (life_story) - 现在包含这部分内容 | |
| if "personal_story" in profile and "personal_story" in profile["personal_story"]: | |
| life_story = profile["personal_story"]["personal_story"] | |
| if life_story: | |
| summary_parts.append(f"Life Story: {life_story}") | |
| # Add summary if available | |
| if "summary" in profile: | |
| summary_parts.append(f"Profile Summary: {profile['summary']}") | |
| # 现在包含完整的based_data信息,包括life_story | |
| return "\n".join(summary_parts) | |
| def _flatten_attributes(self, attributes: Dict, prefix: str = "") -> List[str]: | |
| """ | |
| 将属性字典扁平化为属性路径列表 | |
| 参数: | |
| attributes: 属性字典或子字典 | |
| prefix: 当前路径前缀 | |
| """ | |
| result = [] | |
| def _flatten(attr_dict, curr_prefix): | |
| for k, v in attr_dict.items(): | |
| path = f"{curr_prefix}.{k}" if curr_prefix else k | |
| # 只有叶子节点(空字典)才添加到结果中 | |
| if isinstance(v, dict): | |
| if not v: # 空字典,这是叶子节点 | |
| result.append(path) | |
| else: # 非空字典,继续递归 | |
| _flatten(v, path) | |
| else: # 非字典值,直接添加 | |
| result.append(path) | |
| _flatten(attributes, prefix) | |
| return result | |
| def _get_attribute_categories(self) -> List[str]: | |
| """获取顶级属性类别""" | |
| # 从路径中提取唯一的顶级类别 | |
| return sorted({path.split('.')[0] for path in self.attributes["paths"]}) | |
| def _format_attributes_tree(self, attributes_dict: Dict, prefix: str = "", depth: int = 0) -> List[str]: | |
| """ | |
| 将属性字典格式化为文本格式的树结构 | |
| 参数: | |
| attributes_dict: 属性字典 | |
| prefix: 当前路径前缀 | |
| depth: 树中的当前深度 | |
| 返回: | |
| 表示树的格式化行列表 | |
| """ | |
| lines = [] | |
| for key, value in sorted(attributes_dict.items()): | |
| current_path = f"{prefix}.{key}" if prefix else key | |
| indent = " " * depth | |
| if depth > 0: | |
| prefix_char = "│ " * (depth - 1) + "- " | |
| else: | |
| prefix_char = "- " | |
| lines.append(f"{indent}{prefix_char}{key}") | |
| if isinstance(value, dict) and value: | |
| sub_lines = self._format_attributes_tree(value, current_path, depth + 1) | |
| lines.extend(sub_lines) | |
| return lines | |
| def analyze_profile_for_attributes(self, profile: Dict) -> Dict[str, List[str]]: | |
| """ | |
| 分析用户配置文件并确定应该生成哪些属性 | |
| 参数: | |
| profile: 用户配置文件字典 | |
| 返回: | |
| 包含'recommended'和'not_recommended'属性类别的字典 | |
| """ | |
| # 提取配置文件摘要 | |
| profile_summary = self._extract_profile_summary(profile) | |
| # 获取属性类别 | |
| categories = self._get_attribute_categories() | |
| # 第一步:使用GPT判断是否需要"Career and Work Identity" | |
| career_needed = self._check_if_career_needed(profile, profile_summary) | |
| # 第二步:保留所有一级属性(除了可能被排除的"Career and Work Identity") | |
| recommended_categories = [] | |
| not_recommended_categories = [] | |
| for category in categories: | |
| if category == "Career and Work Identity" and not career_needed: | |
| not_recommended_categories.append(category) | |
| logger.info(f"根据分析,不需要职业和工作身份属性") | |
| else: | |
| recommended_categories.append(category) | |
| # 生成结果 | |
| result = { | |
| "recommended": recommended_categories, | |
| "not_recommended": not_recommended_categories, | |
| "reasoning": f"根据用户背景分析,{'需要' if career_needed else '不需要'}职业和工作身份属性。保留所有其他一级属性,从每个属性中选择最符合用户背景的特征。" | |
| } | |
| return result | |
| def process_profile(self) -> Dict: | |
| """ | |
| 处理用户配置文件并生成属性推荐 | |
| 返回: | |
| 包含配置文件和属性推荐的字典 | |
| """ | |
| # 如果未提供用户配置文件,生成一个 | |
| if not self.user_profile: | |
| self.user_profile = generate_user_profile() | |
| # 提取配置文件摘要 | |
| profile_summary = self._extract_profile_summary(self.user_profile) | |
| # 分析配置文件并获取属性推荐 | |
| attribute_recommendations = self.analyze_profile_for_attributes(self.user_profile) | |
| # 返回结果 | |
| return { | |
| "profile_summary": profile_summary, | |
| "attribute_recommendations": attribute_recommendations | |
| } | |
| def _get_nested_attributes(self, category: str) -> Dict: | |
| """ | |
| 通过从路径重建获取特定类别的嵌套属性 | |
| 参数: | |
| category: 类别名称 | |
| """ | |
| result = {} | |
| # 筛选以该类别开头的路径并重建嵌套结构 | |
| for path in [p for p in self.attributes["paths"] if p.startswith(f"{category}.")]: | |
| parts = path.split('.')[1:] # 跳过第一部分(类别) | |
| current = result | |
| for i, part in enumerate(parts): | |
| if i == len(parts) - 1: | |
| current[part] = {} # 叶节点 | |
| else: | |
| current.setdefault(part, {}) # 如果键不存在则创建 | |
| current = current[part] | |
| return result | |
| def select_top_attributes(self, path_list, target_count=200): | |
| """ | |
| 从路径列表中选择最重要和最具代表性的顶级属性。 | |
| 参数: | |
| path_list: 属性路径列表 | |
| target_count: 要选择的目标属性数量 | |
| 返回: | |
| 选定的属性路径列表 | |
| """ | |
| if not path_list: | |
| return [] | |
| # If we already have fewer paths than target, return all of them | |
| if len(path_list) <= target_count: | |
| return path_list | |
| # Extract unique top-level categories | |
| categories = {} | |
| for path in path_list: | |
| parts = path.split('.') | |
| if len(parts) > 0: | |
| category = parts[0] | |
| if category not in categories: | |
| categories[category] = [] | |
| categories[category].append(path) | |
| # Calculate how many attributes to select from each category | |
| # proportional to their representation in the original list | |
| total_paths = len(path_list) | |
| category_counts = {} | |
| for category, paths in categories.items(): | |
| # Calculate proportional count but ensure at least 1 attribute per category | |
| count = max(1, int(len(paths) / total_paths * target_count)) | |
| category_counts[category] = count | |
| # Adjust counts to match target_count as closely as possible | |
| total_selected = sum(category_counts.values()) | |
| if total_selected < target_count: | |
| # Distribute remaining slots to largest categories | |
| remaining = target_count - total_selected | |
| sorted_categories = sorted(categories.keys(), | |
| key=lambda c: len(categories[c]), | |
| reverse=True) | |
| for i in range(remaining): | |
| if i < len(sorted_categories): | |
| category_counts[sorted_categories[i]] += 1 | |
| elif total_selected > target_count: | |
| # Remove from largest categories until we hit target | |
| excess = total_selected - target_count | |
| sorted_categories = sorted(categories.keys(), | |
| key=lambda c: category_counts[c], | |
| reverse=True) | |
| for i in range(excess): | |
| if i < len(sorted_categories) and category_counts[sorted_categories[i]] > 1: | |
| category_counts[sorted_categories[i]] -= 1 | |
| # Select paths from each category | |
| selected_paths = [] | |
| for category, count in category_counts.items(): | |
| category_paths = categories[category] | |
| # Prioritize paths with fewer segments (more general attributes) | |
| # and paths that represent common attributes across domains | |
| scored_paths = [] | |
| for path in category_paths: | |
| parts = path.split('.') | |
| # Score is based on path depth (shorter is better) and presence of common terms | |
| common_terms = ['general', 'common', 'basic', 'core', 'essential', 'fundamental', 'key'] | |
| common_term_bonus = any(term in path.lower() for term in common_terms) | |
| score = (10 - len(parts)) + (5 if common_term_bonus else 0) | |
| scored_paths.append((path, score)) | |
| # Sort by score (higher is better) and select top paths | |
| sorted_paths = sorted(scored_paths, key=lambda x: x[1], reverse=True) | |
| selected_category_paths = [p[0] for p in sorted_paths[:count]] | |
| selected_paths.extend(selected_category_paths) | |
| return selected_paths | |
| # 注意:删除了_select_best_matching_attributes方法,因为它已经被新的随机选择和GPT筛选方法替代 | |
| def _find_interesting_neighbors(self, profile_embedding: np.ndarray, category_paths: List[str], | |
| target_count: int = 300) -> List[str]: | |
| """ | |
| 实现"有趣的邻居"向量搜索方案 | |
| 参数: | |
| profile_embedding: 用户配置文件的嵌入向量 | |
| category_paths: 特定类别的属性路径列表 | |
| target_count: 要选择的目标属性数量 | |
| 返回: | |
| 选定的属性路径列表或空列表(如果出现读取问题) | |
| """ | |
| # 如果没有向量数据库或路径为空,返回空列表 | |
| if not self.embeddings_data or not category_paths: | |
| logger.warning("没有可用的向量数据库或类别路径为空,返回空列表") | |
| return [] | |
| # 过滤出在向量数据库中有嵌入向量的路径 | |
| valid_paths = [path for path in category_paths if path in self.path_to_embedding] | |
| if not valid_paths: | |
| logger.warning("没有有效的属性路径匹配向量数据库,返回空列表") | |
| return [] | |
| # 计算每个路径与配置文件的相似度 | |
| path_similarities = [] | |
| for path in valid_paths: | |
| embedding = self.path_to_embedding[path] | |
| similarity = self._compute_cosine_similarity(profile_embedding, embedding) | |
| path_similarities.append((path, similarity)) | |
| # 按相似度排序 | |
| path_similarities.sort(key=lambda x: x[1], reverse=True) | |
| # 计算要选择的数量 | |
| total_paths = len(path_similarities) | |
| # 如果路径数量少于目标数量,返回所有路径 | |
| if total_paths <= target_count: | |
| return [p[0] for p in path_similarities] | |
| selected_paths = [] | |
| used_indices = set() | |
| # 按5:3:2比例分配近邻、中距离、远距离邻居 | |
| total_ratio = 5 + 3 + 2 # 总比例 = 10 | |
| # 1. 选择近邻(相似度最高的属性)- 50% (5/10) | |
| near_count = min(int(target_count * 5 / total_ratio), total_paths // 3) | |
| near_indices = list(range(near_count)) | |
| for i in random.sample(near_indices, min(near_count, len(near_indices))): | |
| if i not in used_indices: | |
| selected_paths.append(path_similarities[i][0]) | |
| used_indices.add(i) | |
| # 2. 选择中距离邻居 - 30% (3/10) | |
| mid_start = total_paths // 3 | |
| mid_end = 2 * total_paths // 3 | |
| mid_count = min(int(target_count * 3 / total_ratio), (mid_end - mid_start)) | |
| mid_indices = list(range(mid_start, mid_end)) | |
| if mid_indices: | |
| for i in random.sample(mid_indices, min(mid_count, len(mid_indices))): | |
| if i not in used_indices: | |
| selected_paths.append(path_similarities[i][0]) | |
| used_indices.add(i) | |
| # 3. 选择远距离邻居(相似度最低的属性)- 20% (2/10) | |
| far_start = 2 * total_paths // 3 | |
| far_count = min(int(target_count * 2 / total_ratio), (total_paths - far_start)) | |
| far_indices = list(range(far_start, total_paths)) | |
| if far_indices: | |
| for i in random.sample(far_indices, min(far_count, len(far_indices))): | |
| if i not in used_indices: | |
| selected_paths.append(path_similarities[i][0]) | |
| used_indices.add(i) | |
| # 如果还需要更多属性来达到目标数量,从未使用的索引中随机选择 | |
| remaining_count = target_count - len(selected_paths) | |
| if remaining_count > 0: | |
| remaining_indices = [i for i in range(total_paths) if i not in used_indices] | |
| if remaining_indices: | |
| additional_count = min(remaining_count, len(remaining_indices)) | |
| for i in random.sample(remaining_indices, additional_count): | |
| selected_paths.append(path_similarities[i][0]) | |
| logger.info(f"使用向量搜索选择了 {len(selected_paths)} 条属性(近邻: {near_count}, 中距离: {mid_count}, 远距离: {far_count})") | |
| return selected_paths | |
| def get_top_attributes(self, result: Dict, target_count: int = 200) -> List[str]: | |
| """ | |
| 获取属性列表 | |
| 参数: | |
| result: 来自analyze_profile_for_attributes的结果 | |
| target_count: 目标属性数量 | |
| 返回: | |
| 属性路径列表 | |
| """ | |
| try: | |
| # 收集所有推荐的类别 | |
| all_recommended = set() | |
| if "recommended" in result: | |
| all_recommended.update(result["recommended"]) | |
| # 提取所有路径 | |
| all_paths = [] | |
| category_paths = {} | |
| if "paths" in self.attributes: | |
| # 如果我们使用带有路径的新格式 | |
| for path in self.attributes["paths"]: | |
| # 只包括来自推荐类别的路径 | |
| category = path.split('.')[0] if '.' in path else path | |
| if category in all_recommended: | |
| all_paths.append(path) | |
| # 按类别分组路径 | |
| if category not in category_paths: | |
| category_paths[category] = [] | |
| category_paths[category].append(path) | |
| else: | |
| # 如果我们使用旧的嵌套格式,将其扁平化 | |
| for category in all_recommended: | |
| paths = self._flatten_attributes(self._get_nested_attributes(category), category) | |
| all_paths.extend(paths) | |
| category_paths[category] = paths | |
| # 使用传入的target_count参数,不再随机选择 | |
| # 如果有向量数据库,使用向量搜索 | |
| if self.embeddings_data and self.user_profile: | |
| # 创建用户配置文件的嵌入向量 | |
| profile_embedding = self._create_profile_embedding(self.user_profile) | |
| if profile_embedding is not None: | |
| # 为每个类别选择属性 | |
| final_paths = [] | |
| for category, paths in category_paths.items(): | |
| # 根据类别大小按比例分配目标数量 | |
| category_ratio = len(paths) / len(all_paths) | |
| category_target = max(3, int(target_count * category_ratio)) | |
| # 使用向量搜索选择该类别的属性 | |
| category_selected = self._find_interesting_neighbors( | |
| profile_embedding, paths, category_target | |
| ) | |
| final_paths.extend(category_selected) | |
| # 如果选择的属性超过目标数量,随机减少 | |
| if len(final_paths) > target_count: | |
| final_paths = random.sample(final_paths, target_count) | |
| logger.info(f"使用向量搜索从 {len(all_paths)} 条属性中选择了 {len(final_paths)} 条属性") | |
| return final_paths | |
| # 如果没有向量数据库或向量搜索失败,直接返回空列表 | |
| logger.warning(f"没有可用的向量数据库或向量搜索失败,返回空列表") | |
| return [] | |
| except Exception as e: | |
| logger.error(f"Error generating attribute list: {e}") | |
| # 如果出错,直接返回空列表 | |
| logger.error(f"属性选择过程出错,返回空列表") | |
| return [] | |
| # 注意:删除了_random_select_from_top_categories方法,因为现在直接在get_top_attributes中随机选择属性 | |
| def convert_web_profile_to_selector_format(web_profile: Dict) -> Dict: | |
| """ | |
| 将web_api_bridge.py保存的配置文件格式转换为AttributeSelector期望的格式 | |
| 核心原则:用户输入了什么就用什么,没输入的才生成 | |
| web_profile格式: | |
| { | |
| "age": int, | |
| "gender": str, | |
| "Occupations": [str], | |
| "location": {"city": str, "country": str}, | |
| "personal_values": {"values_orientation": str}, | |
| "life_attitude": str or dict, | |
| "personal_story": {"personal_story": str}, | |
| "interests": {"interests": [str]} | |
| } | |
| selector格式: | |
| { | |
| "age_info": {"age": int, "age_group": str}, | |
| "gender": str, | |
| "location": {"city": str, "country": str}, | |
| "career_info": {"status": str}, | |
| "personal_values": {"values_orientation": str}, | |
| "life_attitude": {"attitude": str, "attitude_details": str, "coping_mechanism": str}, | |
| "personal_story": {"personal_story": str}, | |
| "interests": {"interests": [str]} | |
| } | |
| """ | |
| # 提取年龄信息 - 如果用户提供了就用,否则生成 | |
| age = web_profile.get("age") | |
| if age: | |
| # 根据年龄确定年龄组 | |
| if age <= 6: | |
| age_group = "toddler" | |
| elif age <= 12: | |
| age_group = "child" | |
| elif age <= 19: | |
| age_group = "adolescent" | |
| elif age <= 29: | |
| age_group = "young_adult" | |
| elif age <= 45: | |
| age_group = "adult" | |
| elif age <= 65: | |
| age_group = "middle_aged" | |
| else: | |
| age_group = "senior" | |
| age_info = {"age": age, "age_group": age_group} | |
| print(f"✓ Using user-provided age: {age}") | |
| else: | |
| age_info = generate_age_info() | |
| print(f"✗ Age not provided, generated: {age_info['age']}") | |
| # 提取性别 - 如果用户提供了就用,否则生成 | |
| gender = web_profile.get("gender") | |
| if gender: | |
| print(f"✓ Using user-provided gender: {gender}") | |
| else: | |
| gender = generate_gender() | |
| print(f"✗ Gender not provided, generated: {gender}") | |
| # 提取位置 - 如果用户提供了就用,否则生成 | |
| location = web_profile.get("location") | |
| if location and location.get("city") and location.get("country"): | |
| print(f"✓ Using user-provided location: {location.get('city')}, {location.get('country')}") | |
| else: | |
| location = generate_location() | |
| print(f"✗ Location not provided, generated: {location.get('city')}, {location.get('country')}") | |
| # 提取职业信息 - 如果用户提供了就用,否则生成 | |
| occupations = web_profile.get("Occupations", []) | |
| if occupations and len(occupations) > 0 and occupations[0]: | |
| career_info = {"status": occupations[0]} | |
| print(f"✓ Using user-provided occupation: {occupations[0]}") | |
| else: | |
| career_info = generate_career_info(age_info["age"]) | |
| print(f"✗ Occupation not provided, generated: {career_info['status']}") | |
| # 提取个人价值观 - 如果用户提供了就用,否则生成 | |
| personal_values = web_profile.get("personal_values", {}) | |
| values_orientation = personal_values.get("values_orientation", "") | |
| if values_orientation and values_orientation.strip(): | |
| # 用户提供了价值观,直接使用 | |
| personal_values = {"values_orientation": values_orientation.strip()} | |
| print(f"✓ Using user-provided personal values: {values_orientation[:50]}...") | |
| else: | |
| # 用户没提供,基于已有信息生成 | |
| personal_values = generate_personal_values( | |
| age=age_info["age"], | |
| gender=gender, | |
| occupation=career_info["status"], | |
| location=location | |
| ) | |
| print(f"✗ Personal values not provided, generated based on user inputs") | |
| # 提取生活态度 - 如果用户提供了就用,否则生成 | |
| life_attitude_data = web_profile.get("life_attitude") | |
| if life_attitude_data: | |
| if isinstance(life_attitude_data, str) and life_attitude_data.strip(): | |
| # 用户提供了字符串格式的生活态度 | |
| life_attitude = { | |
| "attitude": life_attitude_data.strip(), | |
| "attitude_details": "", | |
| "coping_mechanism": "" | |
| } | |
| print(f"✓ Using user-provided life attitude: {life_attitude_data[:50]}...") | |
| elif isinstance(life_attitude_data, dict) and life_attitude_data.get("attitude"): | |
| # 用户提供了字典格式的生活态度 | |
| life_attitude = life_attitude_data | |
| print(f"✓ Using user-provided life attitude (dict format)") | |
| else: | |
| # 数据格式不对或为空,生成新的 | |
| life_attitude = generate_life_attitude( | |
| age=age_info["age"], | |
| gender=gender, | |
| occupation=career_info["status"], | |
| location=location, | |
| values_orientation=personal_values.get("values_orientation", "") | |
| ) | |
| print(f"✗ Life attitude not provided, generated based on user inputs") | |
| else: | |
| # 用户没提供,基于已有信息生成 | |
| life_attitude = generate_life_attitude( | |
| age=age_info["age"], | |
| gender=gender, | |
| occupation=career_info["status"], | |
| location=location, | |
| values_orientation=personal_values.get("values_orientation", "") | |
| ) | |
| print(f"✗ Life attitude not provided, generated based on user inputs") | |
| # 提取个人故事 - 如果用户提供了就用,否则生成 | |
| personal_story_data = web_profile.get("personal_story", {}) | |
| if isinstance(personal_story_data, dict) and personal_story_data.get("personal_story"): | |
| story_text = personal_story_data.get("personal_story", "") | |
| if story_text and story_text.strip(): | |
| # 用户提供了故事,直接使用 | |
| personal_story = {"personal_story": story_text.strip()} | |
| print(f"✓ Using user-provided life story: {story_text[:50]}...") | |
| else: | |
| # 故事为空,生成新的 | |
| personal_story = generate_personal_story( | |
| age=age_info["age"], | |
| gender=gender, | |
| occupation=career_info["status"], | |
| location=location, | |
| values_orientation=personal_values.get("values_orientation", ""), | |
| life_attitude=life_attitude | |
| ) | |
| print(f"✗ Life story not provided, generated based on user inputs") | |
| elif isinstance(personal_story_data, str) and personal_story_data.strip(): | |
| # 用户提供了字符串格式的故事 | |
| personal_story = {"personal_story": personal_story_data.strip()} | |
| print(f"✓ Using user-provided life story (string format)") | |
| else: | |
| # 用户没提供,基于已有信息生成 | |
| personal_story = generate_personal_story( | |
| age=age_info["age"], | |
| gender=gender, | |
| occupation=career_info["status"], | |
| location=location, | |
| values_orientation=personal_values.get("values_orientation", ""), | |
| life_attitude=life_attitude | |
| ) | |
| print(f"✗ Life story not provided, generated based on user inputs") | |
| # 提取兴趣爱好 - 如果用户提供了就用,否则生成 | |
| interests = web_profile.get("interests", {}) | |
| interests_list = interests.get("interests", []) | |
| if interests_list and len(interests_list) > 0 and any(i.strip() for i in interests_list): | |
| # 用户提供了兴趣爱好,直接使用 | |
| interests = {"interests": [i.strip() for i in interests_list if i.strip()]} | |
| print(f"✓ Using user-provided interests: {', '.join(interests['interests'])}") | |
| else: | |
| # 用户没提供,基于故事生成 | |
| interests = generate_interests_and_hobbies(personal_story) | |
| print(f"✗ Interests not provided, generated based on life story") | |
| # 返回转换后的格式 | |
| return { | |
| "age_info": age_info, | |
| "gender": gender, | |
| "location": location, | |
| "career_info": career_info, | |
| "personal_values": personal_values, | |
| "life_attitude": life_attitude, | |
| "personal_story": personal_story, | |
| "interests": interests | |
| } | |
| def generate_user_profile() -> Dict: | |
| """生成用户基础信息配置文件""" | |
| # 生成并存储直接函数返回值 | |
| age_info = generate_age_info() | |
| gender = generate_gender() | |
| location = generate_location() | |
| career_info = generate_career_info(age_info["age"]) | |
| # 生成个人价值观 | |
| values = generate_personal_values( | |
| age=age_info["age"], | |
| gender=gender, | |
| occupation=career_info["status"], | |
| location=location | |
| ) | |
| # 生成生活态度 | |
| life_attitude = generate_life_attitude( | |
| age=age_info["age"], | |
| gender=gender, | |
| occupation=career_info["status"], | |
| location=location, | |
| values_orientation=values.get("values_orientation", "") | |
| ) | |
| # 生成个人故事 | |
| personal_story = generate_personal_story( | |
| age=age_info["age"], | |
| gender=gender, | |
| occupation=career_info["status"], | |
| location=location, | |
| values_orientation=values.get("values_orientation", ""), | |
| life_attitude=life_attitude | |
| ) | |
| # 生成兴趣爱好 | |
| interests = generate_interests_and_hobbies(personal_story) | |
| # 存储函数返回值 | |
| user_profile = { | |
| "age_info": age_info, | |
| "gender": gender, | |
| "location": location, | |
| "career_info": career_info, | |
| "personal_values": values, | |
| "life_attitude": life_attitude, | |
| "personal_story": personal_story, | |
| "interests": interests | |
| } | |
| return user_profile | |
| def get_selected_attributes(user_profile=None, attribute_count=200): | |
| global ATTRIBUTE_SELECTION_CACHE | |
| # 注释掉缓存机制,确保每次都重新选择属性 | |
| # if ATTRIBUTE_SELECTION_CACHE is not None: | |
| # return ATTRIBUTE_SELECTION_CACHE | |
| try: | |
| # 如果没有提供用户配置文件,生成一个 | |
| if user_profile is None: | |
| user_profile = generate_user_profile() | |
| # 创建选择器并传入用户配置文件 | |
| selector = AttributeSelector(user_profile=user_profile) | |
| # 处理配置文件 | |
| result = selector.process_profile() | |
| # 获取属性推荐 | |
| attribute_recommendations = result.get("attribute_recommendations", {}) | |
| # 获取属性列表,使用传入的attribute_count参数 | |
| top_paths = selector.get_top_attributes(attribute_recommendations, target_count=attribute_count) | |
| # 返回属性列表 | |
| ATTRIBUTE_SELECTION_CACHE = top_paths | |
| return top_paths | |
| except Exception as e: | |
| logger.error(f"Error getting selected attributes: {e}") | |
| return [] | |
| def build_nested_dict(paths: List[str]) -> Dict: | |
| result = {} | |
| for path in paths: | |
| parts = path.split('.') | |
| current = result | |
| for part in parts: | |
| if part not in current: | |
| current[part] = {} | |
| current = current[part] | |
| return result | |
| def save_results(user_profile: Dict, selected_paths: List[str], output_dir: str = '/home/zhou/persona/generate_user_profile/output') -> None: | |
| """ | |
| 保存用户配置文件和选定的属性路径到文件 | |
| 参数: | |
| user_profile: 用户配置文件 | |
| selected_paths: 选定的属性路径 (列表形式) | |
| output_dir: 输出目录(默认为 '/home/zhou/persona/generate_user_profile/output') | |
| """ | |
| try: | |
| from pathlib import Path | |
| import json | |
| output_path = Path(output_dir) | |
| output_path.mkdir(parents=True, exist_ok=True) | |
| # 保存用户配置文件 | |
| profile_path = output_path / "user_profile.json" | |
| with open(profile_path, 'w', encoding='utf-8') as f: | |
| json.dump(user_profile, f, ensure_ascii=False, indent=2) | |
| logger.info(f"用户配置文件已保存到 {profile_path}") | |
| # 将 selected_paths 转换为嵌套字典结构 | |
| nested_selected_paths = build_nested_dict(selected_paths) | |
| # 保存属性路径(嵌套字典格式) | |
| paths_path = output_path / "selected_paths.json" | |
| with open(paths_path, 'w', encoding='utf-8') as f: | |
| json.dump(nested_selected_paths, f, ensure_ascii=False, indent=2) | |
| logger.info(f"属性路径已保存到 {paths_path}") | |
| except Exception as e: | |
| logger.error(f"保存结果时出错: {e}") | |
| raise | |
| # 示例:在生成用户基本信息和属性列表的函数中自动调用保存(请根据实际情况将此调用添加到合适位置) | |
| user_profile = generate_user_profile() | |
| selected_paths = get_selected_attributes(user_profile) | |
| save_results(user_profile, selected_paths) | |
| # 此文件只供其他文件导入使用 |