# data_generator.py
"""Synthetic Hong Kong intellectual-property (IP) ecosystem graph generator.

Builds a torch_geometric ``HeteroData`` graph with company / patent /
trademark / person / institution nodes, plus ownership, invention,
citation, employment, cooperation and collaboration edges (with reverse
edges added for message passing).
"""
import os
import random
from collections import defaultdict
from datetime import datetime, timedelta  # kept: part of the original public imports
from typing import Dict, List, Tuple

import numpy as np
import torch
from torch_geometric.data import HeteroData


class IPEcosystemGenerator:
    """Hong Kong IP ecosystem network data generator (synthetic, seeded)."""

    def __init__(self, seed=42):
        # Seed every RNG used below so generated graphs are reproducible.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        # Industry mix modelled on Hong Kong (labels are user-facing data).
        self.industries = [
            '金融科技', '生物医药', '人工智能', '半导体',
            '新能源', '电子商务', '物流科技', '智能制造'
        ]

        # District distribution.
        self.districts = [
            '中环', '湾仔', '尖沙咀', '观塘',
            '荃湾', '科学园', '数码港', '将军澳工业邨'
        ]

        # IPC (International Patent Classification) codes per industry.
        # Note some codes (e.g. G06Q) appear under several industries.
        self.ipc_classes = {
            '金融科技': ['G06Q', 'G06F', 'H04L'],
            '生物医药': ['A61K', 'C07D', 'A61P', 'C12N'],
            '人工智能': ['G06N', 'G06K', 'G06T'],
            '半导体': ['H01L', 'H01S', 'G11C'],
            '新能源': ['H01M', 'H02J', 'F03D'],
            '电子商务': ['G06Q', 'H04W', 'G06F'],
            '物流科技': ['G01S', 'B65G', 'G06Q'],
            '智能制造': ['B25J', 'G05B', 'B23Q']
        }

    @staticmethod
    def _edges_to_tensor(edges: List) -> torch.Tensor:
        """Convert an edge list to a (2, E) long tensor.

        FIX: the original ``torch.tensor(edges).t()`` produced a shape-(0,)
        tensor when ``edges`` was empty, which breaks ``.flip(0)`` and PyG
        consumers; return an explicit (2, 0) tensor instead.
        """
        if not edges:
            return torch.empty((2, 0), dtype=torch.long)
        return torch.tensor(edges, dtype=torch.long).t().contiguous()

    @staticmethod
    def _attrs_to_tensor(edge_attrs: List, dim: int) -> torch.Tensor:
        """Convert an edge-attribute list to a (E, dim) float tensor (E may be 0)."""
        if not edge_attrs:
            return torch.empty((0, dim), dtype=torch.float)
        return torch.tensor(edge_attrs, dtype=torch.float)

    def generate(self,
                 n_companies: int = 500,
                 n_patents: int = 3000,
                 n_trademarks: int = 1500,
                 n_persons: int = 2000,
                 n_institutions: int = 50,
                 time_span_years: int = 10) -> HeteroData:
        """Generate the heterogeneous graph.

        Args:
            n_companies: number of company nodes
            n_patents: number of patent nodes
            n_trademarks: number of trademark nodes
            n_persons: number of person nodes (inventors / agents)
            n_institutions: number of institution nodes
            time_span_years: time span of filings, in years

        Returns:
            HeteroData with node features, typed forward edges (some with
            edge attributes) and the matching reverse edges.
        """
        data = HeteroData()

        # ========== nodes ==========
        print("🏢 生成企业节点...")
        companies = self._generate_companies(n_companies)
        data['company'].x = companies['features']
        data['company'].company_id = companies['ids']
        data['company'].industry = companies['industry']

        print("📜 生成专利节点...")
        patents = self._generate_patents(n_patents, time_span_years)
        data['patent'].x = patents['features']
        data['patent'].patent_id = patents['ids']
        data['patent'].year = patents['year']

        print("™️ 生成商标节点...")
        trademarks = self._generate_trademarks(n_trademarks, time_span_years)
        data['trademark'].x = trademarks['features']
        data['trademark'].trademark_id = trademarks['ids']

        print("👤 生成人员节点...")
        persons = self._generate_persons(n_persons)
        data['person'].x = persons['features']
        data['person'].person_id = persons['ids']

        print("🏛️ 生成机构节点...")
        institutions = self._generate_institutions(n_institutions)
        data['institution'].x = institutions['features']
        data['institution'].institution_id = institutions['ids']

        # ========== edges ==========
        print("\n🔗 生成关系边...")

        # 1. company -> patent (owns)
        company_patent_edges = self._generate_ownership_edges(
            n_companies, n_patents, companies['industry'], patents['ipc'])
        data['company', 'owns', 'patent'].edge_index = company_patent_edges['edge_index']
        data['company', 'owns', 'patent'].edge_attr = company_patent_edges['edge_attr']

        # 2. company -> trademark (holds)
        company_trademark_edges = self._generate_trademark_edges(
            n_companies, n_trademarks, avg_per_company=3)
        data['company', 'holds', 'trademark'].edge_index = company_trademark_edges

        # 3. person -> patent (invents)
        person_patent_edges = self._generate_invention_edges(
            n_persons, n_patents, avg_inventors_per_patent=2.5)
        data['person', 'invents', 'patent'].edge_index = person_patent_edges['edge_index']
        data['person', 'invents', 'patent'].edge_attr = person_patent_edges['edge_attr']

        # 4. company -> person (employs), inferred from shared patents
        company_person_edges = self._generate_employment_edges(
            n_companies, n_persons,
            company_patent_edges['edge_index'],
            person_patent_edges['edge_index'])
        data['company', 'employs', 'person'].edge_index = company_person_edges

        # 5. patent -> patent (cites)
        patent_citation_edges = self._generate_citation_edges(
            n_patents, patents['year'], avg_citations=5)
        data['patent', 'cites', 'patent'].edge_index = patent_citation_edges['edge_index']
        data['patent', 'cites', 'patent'].edge_attr = patent_citation_edges['edge_attr']

        # 6. company -> company (cooperates)
        company_cooperation_edges = self._generate_cooperation_edges(
            n_companies, companies['industry'], companies['district'])
        data['company', 'cooperates', 'company'].edge_index = company_cooperation_edges['edge_index']
        data['company', 'cooperates', 'company'].edge_attr = company_cooperation_edges['edge_attr']

        # 7. company -> institution (collaborates)
        company_institution_edges = self._generate_collaboration_edges(
            n_companies, n_institutions, companies['industry'])
        data['company', 'collaborates', 'institution'].edge_index = company_institution_edges

        # 8. institution -> patent (produces)
        institution_patent_edges = self._generate_institution_patent_edges(
            n_institutions, n_patents)
        data['institution', 'produces', 'patent'].edge_index = institution_patent_edges

        # ========== reverse edges ==========
        data['patent', 'owned_by', 'company'].edge_index = \
            data['company', 'owns', 'patent'].edge_index.flip(0)
        data['trademark', 'held_by', 'company'].edge_index = \
            data['company', 'holds', 'trademark'].edge_index.flip(0)
        data['patent', 'invented_by', 'person'].edge_index = \
            data['person', 'invents', 'patent'].edge_index.flip(0)
        data['person', 'employed_by', 'company'].edge_index = \
            data['company', 'employs', 'person'].edge_index.flip(0)
        data['institution', 'collaborated_by', 'company'].edge_index = \
            data['company', 'collaborates', 'institution'].edge_index.flip(0)
        data['patent', 'produced_by', 'institution'].edge_index = \
            data['institution', 'produces', 'patent'].edge_index.flip(0)

        print("\n✅ 数据生成完成!")
        self._print_statistics(data)
        return data

    def _generate_companies(self, n: int) -> Dict:
        """Generate company node features (8-dim, all normalised to ~[0, 1])."""
        features = []
        industries = []
        districts = []

        for i in range(n):
            # Head-count: power-law (Pareto), clipped to 10-500.
            size = np.random.pareto(2) * 50 + 10
            size = min(size, 500)

            # Company age in years (0-30).
            age = np.random.exponential(8)
            age = min(age, 30)

            # R&D spend ratio (0-30%).
            rd_ratio = np.random.beta(2, 5) * 0.3

            # Industry drawn from the empirical distribution (clustering effect).
            industry_idx = np.random.choice(len(self.industries),
                                            p=self._industry_distribution())
            industry = self.industries[industry_idx]
            industries.append(industry_idx)

            # District depends on the industry.
            district_idx = self._select_district(industry)
            districts.append(district_idx)

            # Degree of internationalisation (0-1).
            international = np.random.beta(2, 5)

            # Innovation score (0-100), loosely coupled to R&D ratio.
            innovation_score = rd_ratio * 200 + np.random.normal(50, 15)
            innovation_score = np.clip(innovation_score, 0, 100)

            # Annual revenue (log-normal, in millions of HKD).
            revenue = np.random.lognormal(3, 1.5)

            feature = [
                size / 500,  # normalised
                age / 30,
                rd_ratio / 0.3,
                international,
                innovation_score / 100,
                np.log1p(revenue) / 10,
                industry_idx / len(self.industries),
                district_idx / len(self.districts)
            ]
            features.append(feature)

        return {
            'features': torch.tensor(features, dtype=torch.float),
            'ids': [f'HK-CO-{i:05d}' for i in range(n)],
            'industry': torch.tensor(industries, dtype=torch.long),
            'district': torch.tensor(districts, dtype=torch.long)
        }

    def _generate_patents(self, n: int, time_span: int) -> Dict:
        """Generate patent node features (8-dim) plus filing year and IPC code."""
        features = []
        years = []
        ipcs = []

        end_year = 2025
        start_year = end_year - time_span

        # Filing-year distribution (growth trend over time); loop-invariant,
        # so computed once instead of per patent as in the original.
        year_prob = np.linspace(0.5, 2.0, time_span)
        year_prob = year_prob / year_prob.sum()

        for i in range(n):
            year = np.random.choice(range(start_year, end_year), p=year_prob)
            years.append(year)

            # IPC class drawn via a uniformly random industry.
            industry = random.choice(self.industries)
            ipc = random.choice(self.ipc_classes[industry])
            ipcs.append(ipc)

            # Number of claims (1-30).
            claims = np.random.poisson(10) + 1
            claims = min(claims, 30)

            # Number of inventors (1-10).
            inventors = np.random.poisson(2) + 1
            inventors = min(inventors, 10)

            # Citation count grows with patent age.
            age = end_year - year
            citations = np.random.poisson(age * 0.8)

            # Technology breadth (1-5 IPC subclasses).
            tech_breadth = np.random.poisson(2) + 1
            tech_breadth = min(tech_breadth, 5)

            # Patent value score (0-100).
            value_score = (claims * 2 + citations * 5 + tech_breadth * 10) / 2
            value_score = min(value_score, 100)

            # Grant probability increases with age.
            grant_prob = 0.3 + min(age * 0.1, 0.6)
            is_granted = 1.0 if random.random() < grant_prob else 0.0

            # FIX: the original used hash(ipc), which is salted per process
            # (PYTHONHASHSEED) and therefore NOT reproducible across runs
            # despite seeding. Use a deterministic character-sum encoding.
            ipc_code = sum(ord(c) for c in ipc) % 100 / 100

            feature = [
                (year - start_year) / time_span,
                claims / 30,
                inventors / 10,
                citations / 50,
                tech_breadth / 5,
                value_score / 100,
                is_granted,
                ipc_code
            ]
            features.append(feature)

        return {
            'features': torch.tensor(features, dtype=torch.float),
            'ids': [f'HK-PT-{i:06d}' for i in range(n)],
            'year': torch.tensor(years, dtype=torch.long),
            'ipc': ipcs
        }

    def _generate_trademarks(self, n: int, time_span: int) -> Dict:
        """Generate trademark node features (6-dim)."""
        features = []
        end_year = 2025
        start_year = end_year - time_span

        for i in range(n):
            # Registration year, uniform over the span.
            year = np.random.choice(range(start_year, end_year))

            # Nice classification (classes 1-45).
            nice_class = np.random.randint(1, 46)

            # Renewal count (0-3).
            renewals = min(np.random.poisson(0.5), 3)

            # Mark type: 0 word / 1 figurative / 2 combined.
            tm_type = np.random.choice([0, 1, 2], p=[0.4, 0.3, 0.3])

            # Fame score (0-100).
            fame_score = np.random.beta(2, 8) * 100

            # Dispute records (0-10).
            disputes = np.random.poisson(0.3)
            disputes = min(disputes, 10)

            feature = [
                (year - start_year) / time_span,
                nice_class / 45,
                renewals / 3,
                tm_type / 2,
                fame_score / 100,
                disputes / 10
            ]
            features.append(feature)

        return {
            'features': torch.tensor(features, dtype=torch.float),
            'ids': [f'HK-TM-{i:06d}' for i in range(n)]
        }

    def _generate_persons(self, n: int) -> Dict:
        """Generate person (inventor/agent) node features (6-dim)."""
        features = []

        for i in range(n):
            # Education: 0 bachelor, 1 master, 2 doctorate.
            education = np.random.choice([0, 1, 2], p=[0.3, 0.5, 0.2])

            # Years of experience (0-40).
            experience = np.random.gamma(3, 3)
            experience = min(experience, 40)

            # Patent count (negative binomial, capped at 50).
            patent_count = np.random.negative_binomial(2, 0.3)
            patent_count = min(patent_count, 50)

            # Number of technical fields (1-5).
            tech_fields = np.random.poisson(2) + 1
            tech_fields = min(tech_fields, 5)

            # H-index (research impact), higher for higher education.
            h_index = np.random.poisson(education * 5 + 2)
            h_index = min(h_index, 50)

            # Cross-domain collaboration ability (0-1).
            collaboration = np.random.beta(3, 3)

            feature = [
                education / 2,
                experience / 40,
                patent_count / 50,
                tech_fields / 5,
                h_index / 50,
                collaboration
            ]
            features.append(feature)

        return {
            'features': torch.tensor(features, dtype=torch.float),
            'ids': [f'HK-PS-{i:05d}' for i in range(n)]
        }

    def _generate_institutions(self, n: int) -> Dict:
        """Generate institution node features (6-dim)."""
        features = []
        institution_types = ['大学', '研究所', '孵化器', '政府实验室']

        for i in range(n):
            # Institution type index.
            inst_type = np.random.choice(len(institution_types))

            # Age in years (0-100).
            age = np.random.exponential(15)
            age = min(age, 100)

            # Researcher head-count (log-normal, capped at 1000).
            researchers = np.random.lognormal(4, 1)
            researchers = min(researchers, 1000)

            # Annual patent output.
            patent_output = np.random.poisson(20)

            # International ranking, encoded as reciprocal (higher = better).
            ranking = np.random.pareto(2) + 1
            ranking_score = 1 / ranking

            # Industry-academia collaboration count.
            industry_collab = np.random.poisson(15)

            feature = [
                inst_type / len(institution_types),
                age / 100,
                np.log1p(researchers) / np.log1p(1000),
                patent_output / 100,
                ranking_score,
                industry_collab / 50
            ]
            features.append(feature)

        return {
            'features': torch.tensor(features, dtype=torch.float),
            'ids': [f'HK-IN-{i:04d}' for i in range(n)]
        }

    def _generate_ownership_edges(self, n_companies: int, n_patents: int,
                                  company_industry: torch.Tensor,
                                  patent_ipc: List[str]) -> Dict:
        """Generate company->patent ownership edges.

        Each patent gets 1-3 owners, preferentially from industries whose IPC
        classes contain the patent's IPC code.

        Returns:
            dict with 'edge_index' (2, E) and 'edge_attr' (E, 2):
            [ownership share, importance score].
        """
        edges = []
        edge_attrs = []

        for patent_idx in range(n_patents):
            n_owners = np.random.choice([1, 2, 3], p=[0.7, 0.25, 0.05])

            patent_ipc_code = patent_ipc[patent_idx]

            # Companies in industries matching the patent's IPC class.
            candidate_companies = []
            for industry_idx, industry in enumerate(self.industries):
                if patent_ipc_code in self.ipc_classes[industry]:
                    candidate_companies.extend(
                        (company_industry == industry_idx).nonzero(as_tuple=True)[0].tolist()
                    )

            # FIX: IPC codes shared across industries (e.g. G06Q) produced
            # duplicate candidates, so replace=False could still pick the same
            # company twice; dedupe (sorted for determinism).
            candidate_companies = sorted(set(candidate_companies))

            if not candidate_companies:
                candidate_companies = list(range(n_companies))

            owners = np.random.choice(candidate_companies,
                                      size=min(n_owners, len(candidate_companies)),
                                      replace=False)

            for company_idx in owners:
                edges.append([company_idx, patent_idx])

                # Edge attrs: equal ownership share (by intended owner count,
                # as in the original) and an importance score.
                ownership_share = 1.0 / n_owners
                importance = np.random.beta(3, 2)
                edge_attrs.append([ownership_share, importance])

        return {
            'edge_index': self._edges_to_tensor(edges),
            'edge_attr': self._attrs_to_tensor(edge_attrs, 2)
        }

    def _generate_trademark_edges(self, n_companies: int, n_trademarks: int,
                                  avg_per_company: float) -> torch.Tensor:
        """Generate company->trademark holding edges (Poisson count per company)."""
        edges = []

        for company_idx in range(n_companies):
            n_tm = np.random.poisson(avg_per_company)

            if n_tm > 0:
                trademarks = np.random.choice(n_trademarks,
                                              size=min(n_tm, n_trademarks),
                                              replace=False)
                for tm_idx in trademarks:
                    edges.append([company_idx, tm_idx])

        return self._edges_to_tensor(edges)

    def _generate_invention_edges(self, n_persons: int, n_patents: int,
                                  avg_inventors_per_patent: float) -> Dict:
        """Generate person->patent invention edges.

        Returns:
            dict with 'edge_index' (2, E) and 'edge_attr' (E, 2):
            [inventor rank (first inventor = 1.0), contribution score].
        """
        edges = []
        edge_attrs = []

        for patent_idx in range(n_patents):
            n_inventors = max(1, int(np.random.poisson(avg_inventors_per_patent)))
            n_inventors = min(n_inventors, 10)

            inventors = np.random.choice(n_persons, size=n_inventors, replace=False)

            for rank, person_idx in enumerate(inventors):
                edges.append([person_idx, patent_idx])

                # Rank decays linearly; first inventor contributes most.
                inventor_rank = 1.0 - (rank / n_inventors)
                contribution = np.random.beta(5, 2) if rank == 0 else np.random.beta(2, 5)
                edge_attrs.append([inventor_rank, contribution])

        return {
            'edge_index': self._edges_to_tensor(edges),
            'edge_attr': self._attrs_to_tensor(edge_attrs, 2)
        }

    def _generate_employment_edges(self, n_companies: int, n_persons: int,
                                   company_patent_edges: torch.Tensor,
                                   person_patent_edges: torch.Tensor) -> torch.Tensor:
        """Generate company->person employment edges, inferred from shared patents.

        Each person is assigned to the company whose patent portfolio overlaps
        most with the patents they invented; persons with no patents are
        skipped, persons with patents but no overlapping company get a random
        employer.
        """
        edges = set()

        # Index patents by owner and by inventor.
        company_patents = defaultdict(set)
        person_patents = defaultdict(set)

        for company, patent in company_patent_edges.t().tolist():
            company_patents[company].add(patent)
        for person, patent in person_patent_edges.t().tolist():
            person_patents[person].add(patent)

        for person_idx in range(n_persons):
            person_patent_set = person_patents[person_idx]
            if not person_patent_set:
                continue

            # Companies owning at least one of this person's patents.
            candidate_companies = []
            for company_idx in range(n_companies):
                overlap = len(company_patents[company_idx] & person_patent_set)
                if overlap > 0:
                    candidate_companies.append((company_idx, overlap))

            if candidate_companies:
                # Pick the company with the largest patent overlap.
                candidate_companies.sort(key=lambda x: x[1], reverse=True)
                employer = candidate_companies[0][0]
                edges.add((employer, person_idx))
            else:
                # Fall back to a random employer.
                employer = np.random.randint(0, n_companies)
                edges.add((employer, person_idx))

        # Sorted for deterministic output order (sets are unordered).
        return self._edges_to_tensor(sorted(edges))

    def _generate_citation_edges(self, n_patents: int,
                                 patent_years: torch.Tensor,
                                 avg_citations: float) -> Dict:
        """Generate patent->patent citation edges (only earlier patents citable).

        Returns:
            dict with 'edge_index' (2, E) and 'edge_attr' (E, 2):
            [citation importance, time difference / 10].
        """
        edges = []
        edge_attrs = []

        for patent_idx in range(n_patents):
            current_year = patent_years[patent_idx].item()

            # Strictly earlier patents only.
            earlier_patents = (patent_years < current_year).nonzero(as_tuple=True)[0]

            if len(earlier_patents) > 0:
                n_citations = np.random.poisson(avg_citations)
                n_citations = min(n_citations, len(earlier_patents))

                if n_citations > 0:
                    # Bias towards more recent patents. FIX: use float64 so
                    # np.random.choice's sum-to-1 tolerance check on `p` is
                    # not tripped by float32 rounding.
                    weights = 1.0 / (current_year - patent_years[earlier_patents].double() + 1)
                    weights = weights / weights.sum()

                    cited_patents = np.random.choice(
                        earlier_patents.numpy(),
                        size=n_citations,
                        replace=False,
                        p=weights.numpy()
                    )

                    for cited_idx in cited_patents:
                        edges.append([patent_idx, cited_idx])

                        citation_importance = np.random.beta(3, 2)
                        time_diff = (current_year - patent_years[cited_idx].item()) / 10.0
                        edge_attrs.append([citation_importance, time_diff])

        return {
            'edge_index': self._edges_to_tensor(edges),
            'edge_attr': self._attrs_to_tensor(edge_attrs, 2)
        }

    def _generate_cooperation_edges(self, n_companies: int,
                                    company_industry: torch.Tensor,
                                    company_district: torch.Tensor) -> Dict:
        """Generate undirected-style company<->company cooperation edges.

        Each edge is stored once with src < dst to avoid duplicates.

        Returns:
            dict with 'edge_index' (2, E) and 'edge_attr' (E, 2):
            [cooperation strength, duration / 10].
        """
        edges = []
        edge_attrs = []

        for company_idx in range(n_companies):
            # 0-5 partners per company.
            n_partners = np.random.poisson(2)
            n_partners = min(n_partners, 5)

            if n_partners > 0:
                # Same-industry peers (self excluded). NOTE(review): the
                # original also computed same-district peers but never used
                # them; district affinity is a TODO, not implemented here.
                same_industry = (company_industry == company_industry[company_idx]).nonzero(as_tuple=True)[0]
                same_industry = same_industry[same_industry != company_idx]

                # 70% same-industry, otherwise any other company.
                if len(same_industry) > 0 and random.random() < 0.7:
                    candidates = same_industry.tolist()
                else:
                    candidates = list(range(n_companies))
                    candidates.remove(company_idx)

                if candidates:
                    partners = np.random.choice(candidates,
                                                size=min(n_partners, len(candidates)),
                                                replace=False)

                    for partner_idx in partners:
                        # Keep only src < dst to avoid duplicate pairs.
                        if company_idx < partner_idx:
                            edges.append([company_idx, partner_idx])

                            strength = np.random.beta(3, 3)
                            duration = np.random.exponential(2)
                            duration = min(duration, 10) / 10
                            edge_attrs.append([strength, duration])

        return {
            'edge_index': self._edges_to_tensor(edges),
            'edge_attr': self._attrs_to_tensor(edge_attrs, 2)
        }

    def _generate_collaboration_edges(self, n_companies: int, n_institutions: int,
                                      company_industry: torch.Tensor) -> torch.Tensor:
        """Generate company->institution collaboration edges (~30% of companies)."""
        edges = []

        for company_idx in range(n_companies):
            if random.random() < 0.3:
                n_collabs = np.random.poisson(1) + 1
                n_collabs = min(n_collabs, 3)

                institutions = np.random.choice(n_institutions,
                                                size=n_collabs,
                                                replace=False)
                for inst_idx in institutions:
                    edges.append([company_idx, inst_idx])

        return self._edges_to_tensor(edges)

    def _generate_institution_patent_edges(self, n_institutions: int,
                                           n_patents: int) -> torch.Tensor:
        """Generate institution->patent production edges (Poisson output per institution)."""
        edges = []

        for inst_idx in range(n_institutions):
            n_patents_produced = np.random.poisson(20)
            n_patents_produced = min(n_patents_produced, n_patents // 2)

            if n_patents_produced > 0:
                patents = np.random.choice(n_patents,
                                           size=n_patents_produced,
                                           replace=False)
                for patent_idx in patents:
                    edges.append([inst_idx, patent_idx])

        return self._edges_to_tensor(edges)

    def _industry_distribution(self) -> np.ndarray:
        """Industry weights (fintech and AI dominate, per HK's actual mix)."""
        weights = np.array([0.25, 0.15, 0.20, 0.10, 0.08, 0.10, 0.07, 0.05])
        return weights / weights.sum()

    def _select_district(self, industry: str) -> int:
        """Pick a district index appropriate for the given industry."""
        district_map = {
            '金融科技': ['中环', '湾仔', '尖沙咀'],
            '生物医药': ['科学园', '将军澳工业邨'],
            '人工智能': ['科学园', '数码港'],
            '半导体': ['科学园', '将军澳工业邨'],
            '新能源': ['科学园', '观塘'],
            '电子商务': ['数码港', '荃湾'],
            '物流科技': ['观塘', '荃湾'],
            '智能制造': ['将军澳工业邨', '观塘']
        }

        preferred = district_map.get(industry, self.districts)
        district = random.choice(preferred)
        return self.districts.index(district)

    def _print_statistics(self, data: HeteroData):
        """Print node/edge/feature statistics for the generated graph."""
        print("\n" + "=" * 60)
        print("📊 数据集统计")
        print("=" * 60)

        print(f"\n【节点数量】")
        for node_type in data.node_types:
            print(f"  {node_type:15s}: {data[node_type].num_nodes:6d} 节点")

        print(f"\n【边数量】")
        for edge_type in data.edge_types:
            src, rel, dst = edge_type
            num_edges = data[edge_type].num_edges
            print(f"  {src:12s} --[{rel:12s}]--> {dst:12s}: {num_edges:6d} 条边")

        print(f"\n【特征维度】")
        for node_type in data.node_types:
            print(f"  {node_type:15s}: {data[node_type].num_features} 维")

        print("\n" + "=" * 60)

    @staticmethod
    def load_data(dataset_size='medium', data_dir='data'):
        """Load a saved dataset, generating and caching it if missing/corrupt.

        Args:
            dataset_size: dataset scale ('test', 'medium', 'large')
            data_dir: directory containing the .pt files

        Returns:
            HeteroData: the loaded heterogeneous graph

        Raises:
            ValueError: if dataset_size is not a supported key.
        """
        filename_map = {
            'test': 'hk_ip_test.pt',
            'medium': 'hk_ip_medium.pt',
            'large': 'hk_ip_large.pt'
        }

        if dataset_size not in filename_map:
            raise ValueError(f"不支持的数据集规模: {dataset_size}. 支持的选项: {list(filename_map.keys())}")

        filepath = os.path.join(data_dir, filename_map[dataset_size])

        if not os.path.exists(filepath):
            print(f"❌ 数据文件不存在: {filepath}")
            print(f"🔄 正在生成 {dataset_size} 规模的数据集...")

            # Auto-generate the missing dataset with the default seed.
            generator = IPEcosystemGenerator(seed=42)
            size_configs = {
                'test': {
                    'n_companies': 100, 'n_patents': 500, 'n_trademarks': 200,
                    'n_persons': 300, 'n_institutions': 20, 'time_span_years': 5
                },
                'medium': {
                    'n_companies': 500, 'n_patents': 3000, 'n_trademarks': 1500,
                    'n_persons': 2000, 'n_institutions': 50, 'time_span_years': 10
                },
                'large': {
                    'n_companies': 2000, 'n_patents': 15000, 'n_trademarks': 8000,
                    'n_persons': 10000, 'n_institutions': 200, 'time_span_years': 15
                }
            }

            config = size_configs[dataset_size]
            data = generator.generate(**config)

            os.makedirs(data_dir, exist_ok=True)
            torch.save(data, filepath)
            print(f"✅ 数据已生成并保存到: {filepath}")
            return data

        try:
            print(f"📂 加载数据集: {filepath}")
            # NOTE: weights_only=False unpickles arbitrary objects — only load
            # files this module itself wrote (trusted local cache).
            data = torch.load(filepath, weights_only=False)
            print(f"✅ 成功加载 {dataset_size} 规模数据集")

            print(f"\n📊 数据集概览:")
            for node_type in data.node_types:
                print(f"  {node_type:15s}: {data[node_type].num_nodes:6d} 节点")

            total_edges = sum(data[edge_type].num_edges for edge_type in data.edge_types)
            print(f"  {'总边数':15s}: {total_edges:6d} 条边")
            return data

        except Exception as e:
            print(f"❌ 加载数据失败: {e}")
            print(f"🔄 尝试重新生成数据...")

            # Remove the corrupt file; the recursive call then takes the
            # generate-and-save branch (at most one level of recursion).
            if os.path.exists(filepath):
                os.remove(filepath)
            return IPEcosystemGenerator.load_data(dataset_size, data_dir)

    @staticmethod
    def list_available_datasets(data_dir='data'):
        """List which dataset files exist under ``data_dir`` (with sizes)."""
        if not os.path.exists(data_dir):
            print(f"📁 数据目录不存在: {data_dir}")
            return

        filename_map = {
            'hk_ip_test.pt': 'test (小规模)',
            'hk_ip_medium.pt': 'medium (中等规模)',
            'hk_ip_large.pt': 'large (大规模)'
        }

        print(f"\n📋 可用数据集 (目录: {data_dir}):")
        print("-" * 50)

        found_any = False
        for filename, description in filename_map.items():
            filepath = os.path.join(data_dir, filename)
            if os.path.exists(filepath):
                file_size = os.path.getsize(filepath) / (1024 * 1024)  # MB
                print(f"✅ {description:20s} - {filename:20s} ({file_size:.1f} MB)")
                found_any = True
            else:
                print(f"❌ {description:20s} - {filename:20s} (不存在)")

        if not found_any:
            print("🚫 未找到任何数据集文件")
            print("💡 运行 data_generator.py 来生成数据集")

        print("-" * 50)

    @staticmethod
    def get_dataset_info(dataset_size='medium', data_dir='data'):
        """Return summary info for a saved dataset, or None on failure.

        FIX: the original referenced ``os`` without importing it here
        (only other methods had function-local ``import os``), raising
        NameError; ``os`` is now imported at module level.

        Args:
            dataset_size: dataset scale ('test', 'medium', 'large')
            data_dir: directory containing the .pt files

        Returns:
            dict with file, node and edge statistics, or None.
        """
        filename_map = {
            'test': 'hk_ip_test.pt',
            'medium': 'hk_ip_medium.pt',
            'large': 'hk_ip_large.pt'
        }

        # Consistent with load_data: reject unknown sizes explicitly.
        if dataset_size not in filename_map:
            raise ValueError(f"不支持的数据集规模: {dataset_size}. 支持的选项: {list(filename_map.keys())}")

        filepath = os.path.join(data_dir, filename_map[dataset_size])

        if not os.path.exists(filepath):
            print(f"❌ 数据文件不存在: {filepath}")
            return None

        try:
            # NOTE(review): this loads the full object; torch has no
            # metadata-only load for pickled HeteroData.
            data = torch.load(filepath, map_location='cpu')

            info = {
                'dataset_size': dataset_size,
                'filepath': filepath,
                'file_size_mb': os.path.getsize(filepath) / (1024 * 1024),
                'node_types': data.node_types,
                'edge_types': data.edge_types,
                'node_counts': {ntype: data[ntype].num_nodes for ntype in data.node_types},
                'edge_counts': {etype: data[etype].num_edges for etype in data.edge_types},
                'total_nodes': sum(data[ntype].num_nodes for ntype in data.node_types),
                'total_edges': sum(data[etype].num_edges for etype in data.edge_types)
            }
            return info

        except Exception as e:
            print(f"❌ 读取数据集信息失败: {e}")
            return None


# ========== usage example ==========
if __name__ == "__main__":
    generator = IPEcosystemGenerator(seed=42)

    # FIX: the original wrote to '../data' but list/load default to 'data',
    # so the final load test never found the files it had just saved.
    # Use 'data' consistently (matches the printed message below).
    os.makedirs('data', exist_ok=True)

    # Check existing datasets first.
    print("🔍 检查现有数据集...")
    IPEcosystemGenerator.list_available_datasets()

    # Small test dataset.
    print("\n🧪 生成测试数据集...")
    test_data = generator.generate(
        n_companies=100,
        n_patents=500,
        n_trademarks=200,
        n_persons=300,
        n_institutions=20,
        time_span_years=5
    )

    # Medium dataset.
    print("\n🏗️ 生成中等规模数据集...")
    medium_data = generator.generate(
        n_companies=500,
        n_patents=3000,
        n_trademarks=1500,
        n_persons=2000,
        n_institutions=50,
        time_span_years=10
    )

    # Large dataset.
    print("\n🚀 生成大规模数据集...")
    large_data = generator.generate(
        n_companies=2000,
        n_patents=15000,
        n_trademarks=8000,
        n_persons=10000,
        n_institutions=200,
        time_span_years=15
    )

    # Persist all three.
    torch.save(test_data, 'data/hk_ip_test.pt')
    torch.save(medium_data, 'data/hk_ip_medium.pt')
    torch.save(large_data, 'data/hk_ip_large.pt')

    print("\n✅ 数据已保存到 data/ 目录")

    # Show final state.
    print("\n" + "=" * 60)
    IPEcosystemGenerator.list_available_datasets()

    # Exercise the loading path.
    print("\n🧪 测试数据加载功能...")
    loaded_data = IPEcosystemGenerator.load_data('medium')
    print("✅ 数据加载测试成功!")