Spaces:
Sleeping
Sleeping
| # data_generator.py | |
| import torch | |
| import numpy as np | |
| from torch_geometric.data import HeteroData | |
| from datetime import datetime, timedelta | |
| import random | |
| from typing import Dict, List, Tuple | |
| from collections import defaultdict | |
class IPEcosystemGenerator:
    """Synthetic data generator for a Hong Kong intellectual-property ecosystem network."""

    def __init__(self, seed=42):
        """Seed all three RNGs (random, numpy, torch) so generation is reproducible."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        # Industry sectors reflecting Hong Kong's actual industry mix.
        self.industries = [
            '金融科技', '生物医药', '人工智能', '半导体',
            '新能源', '电子商务', '物流科技', '智能制造'
        ]
        # Hong Kong districts / innovation hubs.
        self.districts = [
            '中环', '湾仔', '尖沙咀', '观塘', '荃湾',
            '科学园', '数码港', '将军澳工业邨'
        ]
        # IPC (International Patent Classification) codes used by each industry.
        self.ipc_classes = {
            '金融科技': ['G06Q', 'G06F', 'H04L'],
            '生物医药': ['A61K', 'C07D', 'A61P', 'C12N'],
            '人工智能': ['G06N', 'G06K', 'G06T'],
            '半导体': ['H01L', 'H01S', 'G11C'],
            '新能源': ['H01M', 'H02J', 'F03D'],
            '电子商务': ['G06Q', 'H04W', 'G06F'],
            '物流科技': ['G01S', 'B65G', 'G06Q'],
            '智能制造': ['B25J', 'G05B', 'B23Q']
        }
    def generate(self,
                 n_companies: int = 500,
                 n_patents: int = 3000,
                 n_trademarks: int = 1500,
                 n_persons: int = 2000,
                 n_institutions: int = 50,
                 time_span_years: int = 10) -> HeteroData:
        """
        Build the heterogeneous IP-ecosystem graph.

        Parameters:
            n_companies: number of company nodes
            n_patents: number of patent nodes
            n_trademarks: number of trademark nodes
            n_persons: number of person nodes (inventors / agents)
            n_institutions: number of institution nodes
            time_span_years: time span covered by the data, in years

        Returns:
            HeteroData with 5 node types, 8 forward edge types (some with
            edge attributes) and their reverse edges.
        """
        data = HeteroData()
        # ========== node generation ==========
        print("🏢 生成企业节点...")
        companies = self._generate_companies(n_companies)
        data['company'].x = companies['features']
        data['company'].company_id = companies['ids']
        data['company'].industry = companies['industry']
        print("📜 生成专利节点...")
        patents = self._generate_patents(n_patents, time_span_years)
        data['patent'].x = patents['features']
        data['patent'].patent_id = patents['ids']
        data['patent'].year = patents['year']
        print("™️ 生成商标节点...")
        trademarks = self._generate_trademarks(n_trademarks, time_span_years)
        data['trademark'].x = trademarks['features']
        data['trademark'].trademark_id = trademarks['ids']
        print("👤 生成人员节点...")
        persons = self._generate_persons(n_persons)
        data['person'].x = persons['features']
        data['person'].person_id = persons['ids']
        print("🏛️ 生成机构节点...")
        institutions = self._generate_institutions(n_institutions)
        data['institution'].x = institutions['features']
        data['institution'].institution_id = institutions['ids']
        # ========== edge generation ==========
        print("\n🔗 生成关系边...")
        # 1. company -> patent (owns), with [share, importance] attributes
        company_patent_edges = self._generate_ownership_edges(
            n_companies, n_patents, companies['industry'], patents['ipc']
        )
        data['company', 'owns', 'patent'].edge_index = company_patent_edges['edge_index']
        data['company', 'owns', 'patent'].edge_attr = company_patent_edges['edge_attr']
        # 2. company -> trademark (holds)
        company_trademark_edges = self._generate_trademark_edges(
            n_companies, n_trademarks, avg_per_company=3
        )
        data['company', 'holds', 'trademark'].edge_index = company_trademark_edges
        # 3. person -> patent (invents), with [rank, contribution] attributes
        person_patent_edges = self._generate_invention_edges(
            n_persons, n_patents, avg_inventors_per_patent=2.5
        )
        data['person', 'invents', 'patent'].edge_index = person_patent_edges['edge_index']
        data['person', 'invents', 'patent'].edge_attr = person_patent_edges['edge_attr']
        # 4. company -> person (employs), inferred from co-occurring patents
        company_person_edges = self._generate_employment_edges(
            n_companies, n_persons, company_patent_edges['edge_index'],
            person_patent_edges['edge_index']
        )
        data['company', 'employs', 'person'].edge_index = company_person_edges
        # 5. patent -> patent (cites), only toward earlier patents
        patent_citation_edges = self._generate_citation_edges(
            n_patents, patents['year'], avg_citations=5
        )
        data['patent', 'cites', 'patent'].edge_index = patent_citation_edges['edge_index']
        data['patent', 'cites', 'patent'].edge_attr = patent_citation_edges['edge_attr']
        # 6. company -> company (cooperates), biased toward same industry
        company_cooperation_edges = self._generate_cooperation_edges(
            n_companies, companies['industry'], companies['district']
        )
        data['company', 'cooperates', 'company'].edge_index = company_cooperation_edges['edge_index']
        data['company', 'cooperates', 'company'].edge_attr = company_cooperation_edges['edge_attr']
        # 7. company -> institution (collaborates)
        company_institution_edges = self._generate_collaboration_edges(
            n_companies, n_institutions, companies['industry']
        )
        data['company', 'collaborates', 'institution'].edge_index = company_institution_edges
        # 8. institution -> patent (produces)
        institution_patent_edges = self._generate_institution_patent_edges(
            n_institutions, n_patents
        )
        data['institution', 'produces', 'patent'].edge_index = institution_patent_edges
        # ========== reverse edges (flip src/dst for bidirectional message passing) ==========
        data['patent', 'owned_by', 'company'].edge_index = data['company', 'owns', 'patent'].edge_index.flip(0)
        data['trademark', 'held_by', 'company'].edge_index = data['company', 'holds', 'trademark'].edge_index.flip(0)
        data['patent', 'invented_by', 'person'].edge_index = data['person', 'invents', 'patent'].edge_index.flip(0)
        data['person', 'employed_by', 'company'].edge_index = data['company', 'employs', 'person'].edge_index.flip(0)
        data['institution', 'collaborated_by', 'company'].edge_index = data[
            'company', 'collaborates', 'institution'].edge_index.flip(0)
        data['patent', 'produced_by', 'institution'].edge_index = data[
            'institution', 'produces', 'patent'].edge_index.flip(0)
        print("\n✅ 数据生成完成!")
        self._print_statistics(data)
        return data
    def _generate_companies(self, n: int) -> Dict:
        """Create company node features.

        Returns:
            Dict with 'features' (n, 8) float tensor, 'ids' (HK-CO-xxxxx),
            and per-company 'industry' / 'district' long index tensors.
        """
        features = []
        industries = []
        districts = []
        for i in range(n):
            # Headcount follows a power law (Pareto), clamped to [10, 500].
            size = np.random.pareto(2) * 50 + 10  # 10-500 employees
            size = min(size, 500)
            # Company age in years (0-30).
            age = np.random.exponential(8)
            age = min(age, 30)
            # R&D spending ratio (0-30%).
            rd_ratio = np.random.beta(2, 5) * 0.3
            # Industry drawn from the city-wide sector distribution.
            industry_idx = np.random.choice(len(self.industries),
                                            p=self._industry_distribution())
            industry = self.industries[industry_idx]
            industries.append(industry_idx)
            # District choice correlates with the industry.
            district_idx = self._select_district(industry)
            districts.append(district_idx)
            # Degree of internationalization (0-1).
            international = np.random.beta(2, 5)
            # Innovation score (0-100), partly driven by R&D intensity.
            innovation_score = rd_ratio * 200 + np.random.normal(50, 15)
            innovation_score = np.clip(innovation_score, 0, 100)
            # Annual revenue, log-normal, in millions of HKD.
            revenue = np.random.lognormal(3, 1.5)
            feature = [
                size / 500,  # all features normalized to roughly [0, 1]
                age / 30,
                rd_ratio / 0.3,
                international,
                innovation_score / 100,
                np.log1p(revenue) / 10,
                industry_idx / len(self.industries),
                district_idx / len(self.districts)
            ]
            features.append(feature)
        return {
            'features': torch.tensor(features, dtype=torch.float),
            'ids': [f'HK-CO-{i:05d}' for i in range(n)],
            'industry': torch.tensor(industries, dtype=torch.long),
            'district': torch.tensor(districts, dtype=torch.long)
        }
| def _generate_patents(self, n: int, time_span: int) -> Dict: | |
| """生成专利节点特征""" | |
| features = [] | |
| years = [] | |
| ipcs = [] | |
| end_year = 2025 | |
| start_year = end_year - time_span | |
| for i in range(n): | |
| # 申请年份 (近年增长趋势) | |
| year_prob = np.linspace(0.5, 2.0, time_span) | |
| year_prob = year_prob / year_prob.sum() | |
| year = np.random.choice(range(start_year, end_year), p=year_prob) | |
| years.append(year) | |
| # IPC分类 | |
| industry = random.choice(self.industries) | |
| ipc = random.choice(self.ipc_classes[industry]) | |
| ipcs.append(ipc) | |
| # 权利要求数量 (1-30) | |
| claims = np.random.poisson(10) + 1 | |
| claims = min(claims, 30) | |
| # 发明人数量 (1-10) | |
| inventors = np.random.poisson(2) + 1 | |
| inventors = min(inventors, 10) | |
| # 引用数量 (随时间增长) | |
| age = end_year - year | |
| citations = np.random.poisson(age * 0.8) | |
| # 技术领域宽度 (1-5个IPC小类) | |
| tech_breadth = np.random.poisson(2) + 1 | |
| tech_breadth = min(tech_breadth, 5) | |
| # 专利价值评分 (0-100) | |
| value_score = (claims * 2 + citations * 5 + tech_breadth * 10) / 2 | |
| value_score = min(value_score, 100) | |
| # 是否获得授权 (时间越长概率越高) | |
| grant_prob = 0.3 + min(age * 0.1, 0.6) | |
| is_granted = 1.0 if random.random() < grant_prob else 0.0 | |
| feature = [ | |
| (year - start_year) / time_span, | |
| claims / 30, | |
| inventors / 10, | |
| citations / 50, | |
| tech_breadth / 5, | |
| value_score / 100, | |
| is_granted, | |
| hash(ipc) % 100 / 100 # IPC编码 | |
| ] | |
| features.append(feature) | |
| return { | |
| 'features': torch.tensor(features, dtype=torch.float), | |
| 'ids': [f'HK-PT-{i:06d}' for i in range(n)], | |
| 'year': torch.tensor(years, dtype=torch.long), | |
| 'ipc': ipcs | |
| } | |
| def _generate_trademarks(self, n: int, time_span: int) -> Dict: | |
| """生成商标节点特征""" | |
| features = [] | |
| end_year = 2025 | |
| start_year = end_year - time_span | |
| for i in range(n): | |
| # 注册年份 | |
| year = np.random.choice(range(start_year, end_year)) | |
| # 商标类别 (1-45类) | |
| nice_class = np.random.randint(1, 46) | |
| # 续展次数 (0-3) | |
| renewals = min(np.random.poisson(0.5), 3) | |
| # 商标类型: 文字/图形/组合 | |
| tm_type = np.random.choice([0, 1, 2], p=[0.4, 0.3, 0.3]) | |
| # 知名度评分 (0-100) | |
| fame_score = np.random.beta(2, 8) * 100 | |
| # 争议记录 (0-10) | |
| disputes = np.random.poisson(0.3) | |
| disputes = min(disputes, 10) | |
| feature = [ | |
| (year - start_year) / time_span, | |
| nice_class / 45, | |
| renewals / 3, | |
| tm_type / 2, | |
| fame_score / 100, | |
| disputes / 10 | |
| ] | |
| features.append(feature) | |
| return { | |
| 'features': torch.tensor(features, dtype=torch.float), | |
| 'ids': [f'HK-TM-{i:06d}' for i in range(n)] | |
| } | |
| def _generate_persons(self, n: int) -> Dict: | |
| """生成人员节点特征""" | |
| features = [] | |
| for i in range(n): | |
| # 学历 (0: 本科, 1: 硕士, 2: 博士) | |
| education = np.random.choice([0, 1, 2], p=[0.3, 0.5, 0.2]) | |
| # 工作年限 (0-40年) | |
| experience = np.random.gamma(3, 3) | |
| experience = min(experience, 40) | |
| # 专利发明数量 | |
| patent_count = np.random.negative_binomial(2, 0.3) | |
| patent_count = min(patent_count, 50) | |
| # 技术领域数量 (1-5) | |
| tech_fields = np.random.poisson(2) + 1 | |
| tech_fields = min(tech_fields, 5) | |
| # H指数 (科研影响力) | |
| h_index = np.random.poisson(education * 5 + 2) | |
| h_index = min(h_index, 50) | |
| # 跨界合作能力 (0-1) | |
| collaboration = np.random.beta(3, 3) | |
| feature = [ | |
| education / 2, | |
| experience / 40, | |
| patent_count / 50, | |
| tech_fields / 5, | |
| h_index / 50, | |
| collaboration | |
| ] | |
| features.append(feature) | |
| return { | |
| 'features': torch.tensor(features, dtype=torch.float), | |
| 'ids': [f'HK-PS-{i:05d}' for i in range(n)] | |
| } | |
| def _generate_institutions(self, n: int) -> Dict: | |
| """生成机构节点特征""" | |
| features = [] | |
| institution_types = ['大学', '研究所', '孵化器', '政府实验室'] | |
| for i in range(n): | |
| # 机构类型 | |
| inst_type = np.random.choice(len(institution_types)) | |
| # 建立年限 | |
| age = np.random.exponential(15) | |
| age = min(age, 100) | |
| # 研究人员数量 | |
| researchers = np.random.lognormal(4, 1) | |
| researchers = min(researchers, 1000) | |
| # 年度专利产出 | |
| patent_output = np.random.poisson(20) | |
| # 国际排名 (归一化后的倒数) | |
| ranking = np.random.pareto(2) + 1 | |
| ranking_score = 1 / ranking | |
| # 产学研合作数量 | |
| industry_collab = np.random.poisson(15) | |
| feature = [ | |
| inst_type / len(institution_types), | |
| age / 100, | |
| np.log1p(researchers) / np.log1p(1000), | |
| patent_output / 100, | |
| ranking_score, | |
| industry_collab / 50 | |
| ] | |
| features.append(feature) | |
| return { | |
| 'features': torch.tensor(features, dtype=torch.float), | |
| 'ids': [f'HK-IN-{i:04d}' for i in range(n)] | |
| } | |
| def _generate_ownership_edges(self, n_companies: int, n_patents: int, | |
| company_industry: torch.Tensor, | |
| patent_ipc: List[str]) -> Dict: | |
| """生成企业-专利所有权边""" | |
| edges = [] | |
| edge_attrs = [] | |
| # 每个专利1-3个所有者 | |
| for patent_idx in range(n_patents): | |
| n_owners = np.random.choice([1, 2, 3], p=[0.7, 0.25, 0.05]) | |
| # 选择同产业或相关产业的企业 | |
| patent_ipc_code = patent_ipc[patent_idx] | |
| # 找到相关产业的企业 | |
| candidate_companies = [] | |
| for industry_idx, industry in enumerate(self.industries): | |
| if patent_ipc_code in self.ipc_classes[industry]: | |
| candidate_companies.extend( | |
| (company_industry == industry_idx).nonzero(as_tuple=True)[0].tolist() | |
| ) | |
| if not candidate_companies: | |
| candidate_companies = list(range(n_companies)) | |
| owners = np.random.choice(candidate_companies, | |
| size=min(n_owners, len(candidate_companies)), | |
| replace=False) | |
| for company_idx in owners: | |
| edges.append([company_idx, patent_idx]) | |
| # 边属性: 所有权份额 | |
| ownership_share = 1.0 / n_owners | |
| # 重要性评分 | |
| importance = np.random.beta(3, 2) | |
| edge_attrs.append([ownership_share, importance]) | |
| return { | |
| 'edge_index': torch.tensor(edges, dtype=torch.long).t().contiguous(), | |
| 'edge_attr': torch.tensor(edge_attrs, dtype=torch.float) | |
| } | |
| def _generate_trademark_edges(self, n_companies: int, n_trademarks: int, | |
| avg_per_company: float) -> torch.Tensor: | |
| """生成企业-商标持有边""" | |
| edges = [] | |
| # 每个企业持有的商标数量 (泊松分布) | |
| for company_idx in range(n_companies): | |
| n_tm = np.random.poisson(avg_per_company) | |
| if n_tm > 0: | |
| trademarks = np.random.choice(n_trademarks, | |
| size=min(n_tm, n_trademarks), | |
| replace=False) | |
| for tm_idx in trademarks: | |
| edges.append([company_idx, tm_idx]) | |
| return torch.tensor(edges, dtype=torch.long).t().contiguous() | |
| def _generate_invention_edges(self, n_persons: int, n_patents: int, | |
| avg_inventors_per_patent: float) -> Dict: | |
| """生成人员-专利发明边""" | |
| edges = [] | |
| edge_attrs = [] | |
| for patent_idx in range(n_patents): | |
| n_inventors = max(1, int(np.random.poisson(avg_inventors_per_patent))) | |
| n_inventors = min(n_inventors, 10) | |
| inventors = np.random.choice(n_persons, size=n_inventors, replace=False) | |
| for rank, person_idx in enumerate(inventors): | |
| edges.append([person_idx, patent_idx]) | |
| # 边属性: 发明人排序 (第一发明人=1.0) | |
| inventor_rank = 1.0 - (rank / n_inventors) | |
| # 贡献度 | |
| contribution = np.random.beta(5, 2) if rank == 0 else np.random.beta(2, 5) | |
| edge_attrs.append([inventor_rank, contribution]) | |
| return { | |
| 'edge_index': torch.tensor(edges, dtype=torch.long).t().contiguous(), | |
| 'edge_attr': torch.tensor(edge_attrs, dtype=torch.float) | |
| } | |
| def _generate_employment_edges(self, n_companies: int, n_persons: int, | |
| company_patent_edges: torch.Tensor, | |
| person_patent_edges: torch.Tensor) -> torch.Tensor: | |
| """生成企业-人员雇佣边 (基于专利合作推断)""" | |
| edges = set() | |
| # 基于共同专利推断雇佣关系 | |
| company_patents = defaultdict(set) | |
| person_patents = defaultdict(set) | |
| for company, patent in company_patent_edges.t().tolist(): | |
| company_patents[company].add(patent) | |
| for person, patent in person_patent_edges.t().tolist(): | |
| person_patents[person].add(patent) | |
| for person_idx in range(n_persons): | |
| person_patent_set = person_patents[person_idx] | |
| if not person_patent_set: | |
| continue | |
| # 找到拥有相同专利的企业 | |
| candidate_companies = [] | |
| for company_idx in range(n_companies): | |
| overlap = len(company_patents[company_idx] & person_patent_set) | |
| if overlap > 0: | |
| candidate_companies.append((company_idx, overlap)) | |
| if candidate_companies: | |
| # 选择专利重叠最多的企业 | |
| candidate_companies.sort(key=lambda x: x[1], reverse=True) | |
| employer = candidate_companies[0][0] | |
| edges.add((employer, person_idx)) | |
| else: | |
| # 随机分配雇主 | |
| employer = np.random.randint(0, n_companies) | |
| edges.add((employer, person_idx)) | |
| return torch.tensor(list(edges), dtype=torch.long).t().contiguous() | |
| def _generate_citation_edges(self, n_patents: int, patent_years: torch.Tensor, | |
| avg_citations: float) -> Dict: | |
| """生成专利引用边 (只能引用更早的专利)""" | |
| edges = [] | |
| edge_attrs = [] | |
| for patent_idx in range(n_patents): | |
| current_year = patent_years[patent_idx].item() | |
| # 可引用的专利 (更早的专利) | |
| earlier_patents = (patent_years < current_year).nonzero(as_tuple=True)[0] | |
| if len(earlier_patents) > 0: | |
| n_citations = np.random.poisson(avg_citations) | |
| n_citations = min(n_citations, len(earlier_patents)) | |
| if n_citations > 0: | |
| # 偏向引用更近期的专利 | |
| weights = 1.0 / (current_year - patent_years[earlier_patents].float() + 1) | |
| weights = weights / weights.sum() | |
| cited_patents = np.random.choice( | |
| earlier_patents.numpy(), | |
| size=n_citations, | |
| replace=False, | |
| p=weights.numpy() | |
| ) | |
| for cited_idx in cited_patents: | |
| edges.append([patent_idx, cited_idx]) | |
| # 引用重要性 | |
| citation_importance = np.random.beta(3, 2) | |
| # 时间差 | |
| time_diff = (current_year - patent_years[cited_idx].item()) / 10.0 | |
| edge_attrs.append([citation_importance, time_diff]) | |
| return { | |
| 'edge_index': torch.tensor(edges, dtype=torch.long).t().contiguous(), | |
| 'edge_attr': torch.tensor(edge_attrs, dtype=torch.float) | |
| } | |
    def _generate_cooperation_edges(self, n_companies: int,
                                    company_industry: torch.Tensor,
                                    company_district: torch.Tensor) -> Dict:
        """Create company-company cooperation edges with [strength, duration] attributes.

        NOTE(review): only pairs with company_idx < partner_idx are kept, so
        partnerships drawn in the other direction are silently dropped rather
        than re-drawn — confirm this thinning is intended.
        """
        edges = []
        edge_attrs = []
        for company_idx in range(n_companies):
            # 0-5 partners per company.
            n_partners = np.random.poisson(2)
            n_partners = min(n_partners, 5)
            if n_partners > 0:
                # Companies in the same industry / same district.
                same_industry = (company_industry == company_industry[company_idx]).nonzero(as_tuple=True)[0]
                same_district = (company_district == company_district[company_idx]).nonzero(as_tuple=True)[0]
                # Exclude the company itself.
                same_industry = same_industry[same_industry != company_idx]
                same_district = same_district[same_district != company_idx]
                # 70% same-industry, 30% cross-industry.
                # NOTE(review): same_district is computed but never used below,
                # so district affinity is currently not applied — confirm.
                candidates = []
                if len(same_industry) > 0 and random.random() < 0.7:
                    candidates = same_industry.tolist()
                else:
                    candidates = list(range(n_companies))
                    candidates.remove(company_idx)
                if candidates:
                    partners = np.random.choice(candidates,
                                                size=min(n_partners, len(candidates)),
                                                replace=False)
                    for partner_idx in partners:
                        # Keep each undirected pair at most once (smaller index first).
                        if company_idx < partner_idx:
                            edges.append([company_idx, partner_idx])
                            # Cooperation strength in (0, 1).
                            strength = np.random.beta(3, 3)
                            # Cooperation duration in years, normalized to [0, 1].
                            duration = np.random.exponential(2)
                            duration = min(duration, 10) / 10
                            edge_attrs.append([strength, duration])
        return {
            'edge_index': torch.tensor(edges, dtype=torch.long).t().contiguous(),
            'edge_attr': torch.tensor(edge_attrs, dtype=torch.float)
        }
| def _generate_collaboration_edges(self, n_companies: int, n_institutions: int, | |
| company_industry: torch.Tensor) -> torch.Tensor: | |
| """生成企业-机构合作边""" | |
| edges = [] | |
| for company_idx in range(n_companies): | |
| # 30% 的企业与机构有合作 | |
| if random.random() < 0.3: | |
| n_collabs = np.random.poisson(1) + 1 | |
| n_collabs = min(n_collabs, 3) | |
| institutions = np.random.choice(n_institutions, size=n_collabs, replace=False) | |
| for inst_idx in institutions: | |
| edges.append([company_idx, inst_idx]) | |
| return torch.tensor(edges, dtype=torch.long).t().contiguous() | |
| def _generate_institution_patent_edges(self, n_institutions: int, | |
| n_patents: int) -> torch.Tensor: | |
| """生成机构-专利产出边""" | |
| edges = [] | |
| # 每个机构产出多个专利 | |
| for inst_idx in range(n_institutions): | |
| n_patents_produced = np.random.poisson(20) | |
| n_patents_produced = min(n_patents_produced, n_patents // 2) | |
| if n_patents_produced > 0: | |
| patents = np.random.choice(n_patents, size=n_patents_produced, replace=False) | |
| for patent_idx in patents: | |
| edges.append([inst_idx, patent_idx]) | |
| return torch.tensor(edges, dtype=torch.long).t().contiguous() | |
| def _industry_distribution(self) -> np.ndarray: | |
| """产业分布 (香港实际情况: 金融科技和AI占比较大)""" | |
| weights = np.array([0.25, 0.15, 0.20, 0.10, 0.08, 0.10, 0.07, 0.05]) | |
| return weights / weights.sum() | |
| def _select_district(self, industry: str) -> int: | |
| """根据产业选择地区""" | |
| district_map = { | |
| '金融科技': ['中环', '湾仔', '尖沙咀'], | |
| '生物医药': ['科学园', '将军澳工业邨'], | |
| '人工智能': ['科学园', '数码港'], | |
| '半导体': ['科学园', '将军澳工业邨'], | |
| '新能源': ['科学园', '观塘'], | |
| '电子商务': ['数码港', '荃湾'], | |
| '物流科技': ['观塘', '荃湾'], | |
| '智能制造': ['将军澳工业邨', '观塘'] | |
| } | |
| preferred = district_map.get(industry, self.districts) | |
| district = random.choice(preferred) | |
| return self.districts.index(district) | |
    def _print_statistics(self, data: HeteroData):
        """Print node counts, edge counts and feature dimensions of the generated graph."""
        print("\n" + "=" * 60)
        print("📊 数据集统计")
        print("=" * 60)
        print(f"\n【节点数量】")
        for node_type in data.node_types:
            print(f" {node_type:15s}: {data[node_type].num_nodes:6d} 节点")
        print(f"\n【边数量】")
        for edge_type in data.edge_types:
            src, rel, dst = edge_type
            num_edges = data[edge_type].num_edges
            print(f" {src:12s} --[{rel:12s}]--> {dst:12s}: {num_edges:6d} 条边")
        print(f"\n【特征维度】")
        for node_type in data.node_types:
            print(f" {node_type:15s}: {data[node_type].num_features} 维")
        print("\n" + "=" * 60)
def load_data(dataset_size='medium', data_dir='data'):
    """Load a previously saved dataset, generating and saving it on demand.

    Fix over the original: the corrupted-file retry path called
    ``IPEcosystemGenerator.load_data(...)`` although ``load_data`` is a
    module-level function; it now recurses via the module-level name.

    Parameters:
        dataset_size: one of 'test', 'medium', 'large'.
        data_dir: directory holding the .pt files.

    Returns:
        HeteroData: the loaded (or freshly generated) heterogeneous graph.

    Raises:
        ValueError: for an unknown *dataset_size*.
    """
    import os
    filename_map = {
        'test': 'hk_ip_test.pt',
        'medium': 'hk_ip_medium.pt',
        'large': 'hk_ip_large.pt'
    }
    if dataset_size not in filename_map:
        raise ValueError(f"不支持的数据集规模: {dataset_size}. 支持的选项: {list(filename_map.keys())}")
    filepath = os.path.join(data_dir, filename_map[dataset_size])
    if not os.path.exists(filepath):
        print(f"❌ 数据文件不存在: {filepath}")
        print(f"🔄 正在生成 {dataset_size} 规模的数据集...")
        # Auto-generate the missing dataset with a fixed seed.
        generator = IPEcosystemGenerator(seed=42)
        size_configs = {
            'test': {
                'n_companies': 100,
                'n_patents': 500,
                'n_trademarks': 200,
                'n_persons': 300,
                'n_institutions': 20,
                'time_span_years': 5
            },
            'medium': {
                'n_companies': 500,
                'n_patents': 3000,
                'n_trademarks': 1500,
                'n_persons': 2000,
                'n_institutions': 50,
                'time_span_years': 10
            },
            'large': {
                'n_companies': 2000,
                'n_patents': 15000,
                'n_trademarks': 8000,
                'n_persons': 10000,
                'n_institutions': 200,
                'time_span_years': 15
            }
        }
        config = size_configs[dataset_size]
        data = generator.generate(**config)
        # Make sure the target directory exists before saving.
        os.makedirs(data_dir, exist_ok=True)
        torch.save(data, filepath)
        print(f"✅ 数据已生成并保存到: {filepath}")
        return data
    try:
        print(f"📂 加载数据集: {filepath}")
        # weights_only=False: the checkpoint stores a full HeteroData object,
        # not just tensors — only load files you trust.
        data = torch.load(filepath, weights_only=False)
        print(f"✅ 成功加载 {dataset_size} 规模数据集")
        # Short overview of the loaded graph.
        print(f"\n📊 数据集概览:")
        for node_type in data.node_types:
            print(f" {node_type:15s}: {data[node_type].num_nodes:6d} 节点")
        total_edges = sum(data[edge_type].num_edges for edge_type in data.edge_types)
        print(f" {'总边数':15s}: {total_edges:6d} 条边")
        return data
    except Exception as e:
        print(f"❌ 加载数据失败: {e}")
        print(f"🔄 尝试重新生成数据...")
        # Drop the corrupted file so the retry regenerates it.
        if os.path.exists(filepath):
            os.remove(filepath)
        # Recurse via the module-level function (the original wrongly went
        # through IPEcosystemGenerator, which has no such attribute).
        return load_data(dataset_size, data_dir)
def list_available_datasets(data_dir='data'):
    """Print which of the known generated dataset files exist under *data_dir*.

    Parameters:
        data_dir: directory to scan for the dataset .pt files.
    """
    import os
    if not os.path.exists(data_dir):
        print(f"📁 数据目录不存在: {data_dir}")
        return
    known_files = {
        'hk_ip_test.pt': 'test (小规模)',
        'hk_ip_medium.pt': 'medium (中等规模)',
        'hk_ip_large.pt': 'large (大规模)'
    }
    print(f"\n📋 可用数据集 (目录: {data_dir}):")
    print("-" * 50)
    any_found = False
    for fname, label in known_files.items():
        full_path = os.path.join(data_dir, fname)
        if not os.path.exists(full_path):
            print(f"❌ {label:20s} - {fname:20s} (不存在)")
            continue
        size_mb = os.path.getsize(full_path) / (1024 * 1024)  # bytes -> MB
        print(f"✅ {label:20s} - {fname:20s} ({size_mb:.1f} MB)")
        any_found = True
    if not any_found:
        print("🚫 未找到任何数据集文件")
        print("💡 运行 data_generator.py 来生成数据集")
    print("-" * 50)
def get_dataset_info(dataset_size='medium', data_dir='data'):
    """Return summary information about a saved dataset.

    Fixes over the original: ``os`` was used without being imported
    (NameError on every call), and an unknown *dataset_size* raised a bare
    KeyError instead of the ValueError that ``load_data`` raises.

    Parameters:
        dataset_size: one of 'test', 'medium', 'large'.
        data_dir: directory holding the .pt files.

    Returns:
        dict with file/node/edge statistics, or None if the file is
        missing or unreadable.

    Raises:
        ValueError: for an unknown *dataset_size*.
    """
    import os
    filename_map = {
        'test': 'hk_ip_test.pt',
        'medium': 'hk_ip_medium.pt',
        'large': 'hk_ip_large.pt'
    }
    if dataset_size not in filename_map:
        raise ValueError(f"不支持的数据集规模: {dataset_size}. 支持的选项: {list(filename_map.keys())}")
    filepath = os.path.join(data_dir, filename_map[dataset_size])
    if not os.path.exists(filepath):
        print(f"❌ 数据文件不存在: {filepath}")
        return None
    try:
        # NOTE: torch.load still deserializes the full object; there is no
        # metadata-only fast path for these HeteroData checkpoints.
        data = torch.load(filepath, map_location='cpu')
        info = {
            'dataset_size': dataset_size,
            'filepath': filepath,
            'file_size_mb': os.path.getsize(filepath) / (1024 * 1024),
            'node_types': data.node_types,
            'edge_types': data.edge_types,
            'node_counts': {ntype: data[ntype].num_nodes for ntype in data.node_types},
            'edge_counts': {etype: data[etype].num_edges for etype in data.edge_types},
            'total_nodes': sum(data[ntype].num_nodes for ntype in data.node_types),
            'total_edges': sum(data[etype].num_edges for etype in data.edge_types)
        }
        return info
    except Exception as e:
        print(f"❌ 读取数据集信息失败: {e}")
        return None
# ========== Usage example ==========
if __name__ == "__main__":
    # Fixed seed so the three datasets are reproducible.
    generator = IPEcosystemGenerator(seed=42)
    # Make sure the output directory exists before saving.
    import os
    os.makedirs('../data', exist_ok=True)
    # Show what is already on disk.
    # Fixes vs. the original: list_available_datasets / load_data are
    # module-level functions, not attributes of IPEcosystemGenerator, and
    # the directory actually written to ('../data') is passed explicitly
    # instead of the default 'data'.
    print("🔍 检查现有数据集...")
    list_available_datasets('../data')
    # Small dataset for quick tests.
    print("\n🧪 生成测试数据集...")
    test_data = generator.generate(
        n_companies=100,
        n_patents=500,
        n_trademarks=200,
        n_persons=300,
        n_institutions=20,
        time_span_years=5
    )
    # Medium-sized dataset.
    print("\n🏗️ 生成中等规模数据集...")
    medium_data = generator.generate(
        n_companies=500,
        n_patents=3000,
        n_trademarks=1500,
        n_persons=2000,
        n_institutions=50,
        time_span_years=10
    )
    # Large dataset.
    print("\n🚀 生成大规模数据集...")
    large_data = generator.generate(
        n_companies=2000,
        n_patents=15000,
        n_trademarks=8000,
        n_persons=10000,
        n_institutions=200,
        time_span_years=15
    )
    # Persist all three datasets.
    torch.save(test_data, '../data/hk_ip_test.pt')
    torch.save(medium_data, '../data/hk_ip_medium.pt')
    torch.save(large_data, '../data/hk_ip_large.pt')
    print("\n✅ 数据已保存到 data/ 目录")
    # Final on-disk status.
    print("\n" + "=" * 60)
    list_available_datasets('../data')
    # Round-trip check of the loading helper.
    print("\n🧪 测试数据加载功能...")
    loaded_data = load_data('medium', data_dir='../data')
    print("✅ 数据加载测试成功!")