# data_generator.py
import os
import torch
import numpy as np
from torch_geometric.data import HeteroData
from datetime import datetime, timedelta
import random
from typing import Dict, List, Tuple
from collections import defaultdict
class IPEcosystemGenerator:
    """Synthetic data generator for a Hong Kong intellectual-property (IP) ecosystem network."""
    def __init__(self, seed=42):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        # Industry mix loosely modelled on Hong Kong's innovation sectors
        self.industries = [
            '金融科技', '生物医药', '人工智能', '半导体',
            '新能源', '电子商务', '物流科技', '智能制造'
        ]
        # Districts / innovation hubs in Hong Kong
        self.districts = [
            '中环', '湾仔', '尖沙咀', '观塘', '荃湾',
            '科学园', '数码港', '将军澳工业邨'
        ]
        # IPC classes (International Patent Classification) per industry
        self.ipc_classes = {
            '金融科技': ['G06Q', 'G06F', 'H04L'],
            '生物医药': ['A61K', 'C07D', 'A61P', 'C12N'],
            '人工智能': ['G06N', 'G06K', 'G06T'],
            '半导体': ['H01L', 'H01S', 'G11C'],
            '新能源': ['H01M', 'H02J', 'F03D'],
            '电子商务': ['G06Q', 'H04W', 'G06F'],
            '物流科技': ['G01S', 'B65G', 'G06Q'],
            '智能制造': ['B25J', 'G05B', 'B23Q']
        }
    def generate(self,
                 n_companies: int = 500,
                 n_patents: int = 3000,
                 n_trademarks: int = 1500,
                 n_persons: int = 2000,
                 n_institutions: int = 50,
                 time_span_years: int = 10) -> HeteroData:
        """
        Generate the heterogeneous graph.
        Args:
            n_companies: number of company nodes
            n_patents: number of patent nodes
            n_trademarks: number of trademark nodes
            n_persons: number of person nodes (inventors / agents)
            n_institutions: number of institution nodes
            time_span_years: time span in years
        """
        data = HeteroData()
        # ========== Node generation ==========
        print("🏢 Generating company nodes...")
        companies = self._generate_companies(n_companies)
        data['company'].x = companies['features']
        data['company'].company_id = companies['ids']
        data['company'].industry = companies['industry']
        print("📜 Generating patent nodes...")
        patents = self._generate_patents(n_patents, time_span_years)
        data['patent'].x = patents['features']
        data['patent'].patent_id = patents['ids']
        data['patent'].year = patents['year']
        print("™️ Generating trademark nodes...")
        trademarks = self._generate_trademarks(n_trademarks, time_span_years)
        data['trademark'].x = trademarks['features']
        data['trademark'].trademark_id = trademarks['ids']
        print("👤 Generating person nodes...")
        persons = self._generate_persons(n_persons)
        data['person'].x = persons['features']
        data['person'].person_id = persons['ids']
        print("🏛️ Generating institution nodes...")
        institutions = self._generate_institutions(n_institutions)
        data['institution'].x = institutions['features']
        data['institution'].institution_id = institutions['ids']
        # ========== Edge generation ==========
        print("\n🔗 Generating relation edges...")
        # 1. company → patent (owns)
        company_patent_edges = self._generate_ownership_edges(
            n_companies, n_patents, companies['industry'], patents['ipc']
        )
        data['company', 'owns', 'patent'].edge_index = company_patent_edges['edge_index']
        data['company', 'owns', 'patent'].edge_attr = company_patent_edges['edge_attr']
        # 2. company → trademark (holds)
        company_trademark_edges = self._generate_trademark_edges(
            n_companies, n_trademarks, avg_per_company=3
        )
        data['company', 'holds', 'trademark'].edge_index = company_trademark_edges
        # 3. person → patent (invents)
        person_patent_edges = self._generate_invention_edges(
            n_persons, n_patents, avg_inventors_per_patent=2.5
        )
        data['person', 'invents', 'patent'].edge_index = person_patent_edges['edge_index']
        data['person', 'invents', 'patent'].edge_attr = person_patent_edges['edge_attr']
        # 4. company → person (employs)
        company_person_edges = self._generate_employment_edges(
            n_companies, n_persons, company_patent_edges['edge_index'],
            person_patent_edges['edge_index']
        )
        data['company', 'employs', 'person'].edge_index = company_person_edges
        # 5. patent → patent (cites)
        patent_citation_edges = self._generate_citation_edges(
            n_patents, patents['year'], avg_citations=5
        )
        data['patent', 'cites', 'patent'].edge_index = patent_citation_edges['edge_index']
        data['patent', 'cites', 'patent'].edge_attr = patent_citation_edges['edge_attr']
        # 6. company → company (cooperates)
        company_cooperation_edges = self._generate_cooperation_edges(
            n_companies, companies['industry'], companies['district']
        )
        data['company', 'cooperates', 'company'].edge_index = company_cooperation_edges['edge_index']
        data['company', 'cooperates', 'company'].edge_attr = company_cooperation_edges['edge_attr']
        # 7. company → institution (collaborates)
        company_institution_edges = self._generate_collaboration_edges(
            n_companies, n_institutions, companies['industry']
        )
        data['company', 'collaborates', 'institution'].edge_index = company_institution_edges
        # 8. institution → patent (produces)
        institution_patent_edges = self._generate_institution_patent_edges(
            n_institutions, n_patents
        )
        data['institution', 'produces', 'patent'].edge_index = institution_patent_edges
        # ========== Reverse edges ==========
        data['patent', 'owned_by', 'company'].edge_index = data['company', 'owns', 'patent'].edge_index.flip(0)
        data['trademark', 'held_by', 'company'].edge_index = data['company', 'holds', 'trademark'].edge_index.flip(0)
        data['patent', 'invented_by', 'person'].edge_index = data['person', 'invents', 'patent'].edge_index.flip(0)
        data['person', 'employed_by', 'company'].edge_index = data['company', 'employs', 'person'].edge_index.flip(0)
        data['institution', 'collaborated_by', 'company'].edge_index = data[
            'company', 'collaborates', 'institution'].edge_index.flip(0)
        data['patent', 'produced_by', 'institution'].edge_index = data[
            'institution', 'produces', 'patent'].edge_index.flip(0)
        print("\n✅ Data generation complete!")
        self._print_statistics(data)
        return data
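    # Illustrative usage sketch (not executed here; shapes assume the default "test" sizes):
    #     gen = IPEcosystemGenerator(seed=42)
    #     data = gen.generate(n_companies=100, n_patents=500, n_trademarks=200,
    #                         n_persons=300, n_institutions=20, time_span_years=5)
    #     data['company'].x                              # FloatTensor of shape [100, 8]
    #     data['company', 'owns', 'patent'].edge_index   # LongTensor of shape [2, num_owns_edges]
    #     data.metadata()                                # (node_types, edge_types), e.g. for to_hetero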
    def _generate_companies(self, n: int) -> Dict:
        """Generate company node features."""
        features = []
        industries = []
        districts = []
        for i in range(n):
            # Company size (power-law distributed, roughly 10-500 employees)
            size = np.random.pareto(2) * 50 + 10
            size = min(size, 500)
            # Company age in years (0-30)
            age = np.random.exponential(8)
            age = min(age, 30)
            # R&D spending ratio (0-30%)
            rd_ratio = np.random.beta(2, 5) * 0.3
            # Industry assignment (with clustering effects)
            industry_idx = np.random.choice(len(self.industries),
                                            p=self._industry_distribution())
            industry = self.industries[industry_idx]
            industries.append(industry_idx)
            # District assignment (correlated with industry)
            district_idx = self._select_district(industry)
            districts.append(district_idx)
            # Degree of internationalisation (0-1)
            international = np.random.beta(2, 5)
            # Innovation capability score (0-100)
            innovation_score = rd_ratio * 200 + np.random.normal(50, 15)
            innovation_score = np.clip(innovation_score, 0, 100)
            # Annual revenue (log-normal, in millions of HKD)
            revenue = np.random.lognormal(3, 1.5)
            feature = [
                size / 500,  # normalised
                age / 30,
                rd_ratio / 0.3,
                international,
                innovation_score / 100,
                np.log1p(revenue) / 10,
                industry_idx / len(self.industries),
                district_idx / len(self.districts)
            ]
            features.append(feature)
        return {
            'features': torch.tensor(features, dtype=torch.float),
            'ids': [f'HK-CO-{i:05d}' for i in range(n)],
            'industry': torch.tensor(industries, dtype=torch.long),
            'district': torch.tensor(districts, dtype=torch.long)
        }
    def _generate_patents(self, n: int, time_span: int) -> Dict:
        """Generate patent node features."""
        features = []
        years = []
        ipcs = []
        end_year = 2025
        start_year = end_year - time_span
        # Deterministic IPC encoding: Python's built-in hash() is randomised per process,
        # which would break seeded reproducibility, so use a fixed sorted index instead.
        all_ipc = sorted({code for codes in self.ipc_classes.values() for code in codes})
        for i in range(n):
            # Filing year (weighted toward recent years to mimic growth)
            year_prob = np.linspace(0.5, 2.0, time_span)
            year_prob = year_prob / year_prob.sum()
            year = np.random.choice(range(start_year, end_year), p=year_prob)
            years.append(year)
            # IPC class
            industry = random.choice(self.industries)
            ipc = random.choice(self.ipc_classes[industry])
            ipcs.append(ipc)
            # Number of claims (1-30)
            claims = np.random.poisson(10) + 1
            claims = min(claims, 30)
            # Number of inventors (1-10)
            inventors = np.random.poisson(2) + 1
            inventors = min(inventors, 10)
            # Citation count (grows with patent age)
            age = end_year - year
            citations = np.random.poisson(age * 0.8)
            # Technology breadth (1-5 IPC subclasses)
            tech_breadth = np.random.poisson(2) + 1
            tech_breadth = min(tech_breadth, 5)
            # Patent value score (0-100)
            value_score = (claims * 2 + citations * 5 + tech_breadth * 10) / 2
            value_score = min(value_score, 100)
            # Grant status (older filings are more likely to be granted)
            grant_prob = 0.3 + min(age * 0.1, 0.6)
            is_granted = 1.0 if random.random() < grant_prob else 0.0
            feature = [
                (year - start_year) / time_span,
                claims / 30,
                inventors / 10,
                citations / 50,
                tech_breadth / 5,
                value_score / 100,
                is_granted,
                all_ipc.index(ipc) / len(all_ipc)  # IPC encoding
            ]
            features.append(feature)
        return {
            'features': torch.tensor(features, dtype=torch.float),
            'ids': [f'HK-PT-{i:06d}' for i in range(n)],
            'year': torch.tensor(years, dtype=torch.long),
            'ipc': ipcs
        }
    def _generate_trademarks(self, n: int, time_span: int) -> Dict:
        """Generate trademark node features."""
        features = []
        end_year = 2025
        start_year = end_year - time_span
        for i in range(n):
            # Registration year
            year = np.random.choice(range(start_year, end_year))
            # Nice classification (classes 1-45)
            nice_class = np.random.randint(1, 46)
            # Number of renewals (0-3)
            renewals = min(np.random.poisson(0.5), 3)
            # Trademark type: word / figurative / combined
            tm_type = np.random.choice([0, 1, 2], p=[0.4, 0.3, 0.3])
            # Reputation score (0-100)
            fame_score = np.random.beta(2, 8) * 100
            # Dispute records (0-10)
            disputes = np.random.poisson(0.3)
            disputes = min(disputes, 10)
            feature = [
                (year - start_year) / time_span,
                nice_class / 45,
                renewals / 3,
                tm_type / 2,
                fame_score / 100,
                disputes / 10
            ]
            features.append(feature)
        return {
            'features': torch.tensor(features, dtype=torch.float),
            'ids': [f'HK-TM-{i:06d}' for i in range(n)]
        }
    def _generate_persons(self, n: int) -> Dict:
        """Generate person node features."""
        features = []
        for i in range(n):
            # Education level (0: bachelor, 1: master, 2: PhD)
            education = np.random.choice([0, 1, 2], p=[0.3, 0.5, 0.2])
            # Years of experience (0-40)
            experience = np.random.gamma(3, 3)
            experience = min(experience, 40)
            # Number of patents invented
            patent_count = np.random.negative_binomial(2, 0.3)
            patent_count = min(patent_count, 50)
            # Number of technology fields (1-5)
            tech_fields = np.random.poisson(2) + 1
            tech_fields = min(tech_fields, 5)
            # h-index (research impact)
            h_index = np.random.poisson(education * 5 + 2)
            h_index = min(h_index, 50)
            # Cross-domain collaboration ability (0-1)
            collaboration = np.random.beta(3, 3)
            feature = [
                education / 2,
                experience / 40,
                patent_count / 50,
                tech_fields / 5,
                h_index / 50,
                collaboration
            ]
            features.append(feature)
        return {
            'features': torch.tensor(features, dtype=torch.float),
            'ids': [f'HK-PS-{i:05d}' for i in range(n)]
        }
    def _generate_institutions(self, n: int) -> Dict:
        """Generate institution node features."""
        features = []
        institution_types = ['大学', '研究所', '孵化器', '政府实验室']
        for i in range(n):
            # Institution type (university / research institute / incubator / government lab)
            inst_type = np.random.choice(len(institution_types))
            # Age of the institution in years
            age = np.random.exponential(15)
            age = min(age, 100)
            # Number of researchers
            researchers = np.random.lognormal(4, 1)
            researchers = min(researchers, 1000)
            # Annual patent output
            patent_output = np.random.poisson(20)
            # International ranking (normalised reciprocal)
            ranking = np.random.pareto(2) + 1
            ranking_score = 1 / ranking
            # Number of industry-academia collaborations
            industry_collab = np.random.poisson(15)
            feature = [
                inst_type / len(institution_types),
                age / 100,
                np.log1p(researchers) / np.log1p(1000),
                patent_output / 100,
                ranking_score,
                industry_collab / 50
            ]
            features.append(feature)
        return {
            'features': torch.tensor(features, dtype=torch.float),
            'ids': [f'HK-IN-{i:04d}' for i in range(n)]
        }
    def _generate_ownership_edges(self, n_companies: int, n_patents: int,
                                  company_industry: torch.Tensor,
                                  patent_ipc: List[str]) -> Dict:
        """Generate company-patent ownership edges."""
        edges = []
        edge_attrs = []
        # Each patent has 1-3 owners
        for patent_idx in range(n_patents):
            n_owners = np.random.choice([1, 2, 3], p=[0.7, 0.25, 0.05])
            # Prefer companies whose industry matches the patent's IPC class
            patent_ipc_code = patent_ipc[patent_idx]
            # Collect companies in related industries
            candidate_companies = []
            for industry_idx, industry in enumerate(self.industries):
                if patent_ipc_code in self.ipc_classes[industry]:
                    candidate_companies.extend(
                        (company_industry == industry_idx).nonzero(as_tuple=True)[0].tolist()
                    )
            if not candidate_companies:
                candidate_companies = list(range(n_companies))
            owners = np.random.choice(candidate_companies,
                                      size=min(n_owners, len(candidate_companies)),
                                      replace=False)
            for company_idx in owners:
                edges.append([company_idx, patent_idx])
                # Edge attribute: ownership share, split over the owners actually drawn
                ownership_share = 1.0 / len(owners)
                # Importance score
                importance = np.random.beta(3, 2)
                edge_attrs.append([ownership_share, importance])
        return {
            'edge_index': torch.tensor(edges, dtype=torch.long).t().contiguous(),
            'edge_attr': torch.tensor(edge_attrs, dtype=torch.float)
        }
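    # Note on conventions (descriptive, applies to all edge builders below): edge_index uses
    # the PyG layout of shape [2, num_edges], row 0 = source indices, row 1 = destination
    # indices; edge_attr has one row per edge, aligned with edge_index by position.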
    def _generate_trademark_edges(self, n_companies: int, n_trademarks: int,
                                  avg_per_company: float) -> torch.Tensor:
        """Generate company-trademark holding edges."""
        edges = []
        # Number of trademarks held per company (Poisson distributed)
        for company_idx in range(n_companies):
            n_tm = np.random.poisson(avg_per_company)
            if n_tm > 0:
                trademarks = np.random.choice(n_trademarks,
                                              size=min(n_tm, n_trademarks),
                                              replace=False)
                for tm_idx in trademarks:
                    edges.append([company_idx, tm_idx])
        return torch.tensor(edges, dtype=torch.long).t().contiguous()
    def _generate_invention_edges(self, n_persons: int, n_patents: int,
                                  avg_inventors_per_patent: float) -> Dict:
        """Generate person-patent invention edges."""
        edges = []
        edge_attrs = []
        for patent_idx in range(n_patents):
            n_inventors = max(1, int(np.random.poisson(avg_inventors_per_patent)))
            n_inventors = min(n_inventors, 10)
            inventors = np.random.choice(n_persons, size=n_inventors, replace=False)
            for rank, person_idx in enumerate(inventors):
                edges.append([person_idx, patent_idx])
                # Edge attribute: inventor rank (first inventor = 1.0)
                inventor_rank = 1.0 - (rank / n_inventors)
                # Contribution weight (first inventor skewed high)
                contribution = np.random.beta(5, 2) if rank == 0 else np.random.beta(2, 5)
                edge_attrs.append([inventor_rank, contribution])
        return {
            'edge_index': torch.tensor(edges, dtype=torch.long).t().contiguous(),
            'edge_attr': torch.tensor(edge_attrs, dtype=torch.float)
        }
    def _generate_employment_edges(self, n_companies: int, n_persons: int,
                                   company_patent_edges: torch.Tensor,
                                   person_patent_edges: torch.Tensor) -> torch.Tensor:
        """Generate company-person employment edges (inferred from shared patents)."""
        edges = set()
        # Infer employment from co-occurrence on patents
        company_patents = defaultdict(set)
        person_patents = defaultdict(set)
        for company, patent in company_patent_edges.t().tolist():
            company_patents[company].add(patent)
        for person, patent in person_patent_edges.t().tolist():
            person_patents[person].add(patent)
        for person_idx in range(n_persons):
            person_patent_set = person_patents[person_idx]
            if not person_patent_set:
                continue
            # Find companies that own the same patents
            candidate_companies = []
            for company_idx in range(n_companies):
                overlap = len(company_patents[company_idx] & person_patent_set)
                if overlap > 0:
                    candidate_companies.append((company_idx, overlap))
            if candidate_companies:
                # Pick the company with the largest patent overlap
                candidate_companies.sort(key=lambda x: x[1], reverse=True)
                employer = candidate_companies[0][0]
                edges.add((employer, person_idx))
            else:
                # Fall back to a random employer
                employer = np.random.randint(0, n_companies)
                edges.add((employer, person_idx))
        return torch.tensor(list(edges), dtype=torch.long).t().contiguous()
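    # Descriptive note: every person that appears on at least one patent receives exactly
    # one employer edge here; persons with no patents are skipped and remain unemployed,
    # so the number of 'employs' edges is at most n_persons.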
    def _generate_citation_edges(self, n_patents: int, patent_years: torch.Tensor,
                                 avg_citations: float) -> Dict:
        """Generate patent citation edges (a patent may only cite earlier patents)."""
        edges = []
        edge_attrs = []
        for patent_idx in range(n_patents):
            current_year = patent_years[patent_idx].item()
            # Citable patents (filed in strictly earlier years)
            earlier_patents = (patent_years < current_year).nonzero(as_tuple=True)[0]
            if len(earlier_patents) > 0:
                n_citations = np.random.poisson(avg_citations)
                n_citations = min(n_citations, len(earlier_patents))
                if n_citations > 0:
                    # Bias citations toward more recent patents
                    weights = 1.0 / (current_year - patent_years[earlier_patents].float() + 1)
                    weights = weights / weights.sum()
                    cited_patents = np.random.choice(
                        earlier_patents.numpy(),
                        size=n_citations,
                        replace=False,
                        p=weights.numpy()
                    )
                    for cited_idx in cited_patents:
                        edges.append([patent_idx, cited_idx])
                        # Citation importance
                        citation_importance = np.random.beta(3, 2)
                        # Time gap between citing and cited patent (scaled by a decade)
                        time_diff = (current_year - patent_years[cited_idx].item()) / 10.0
                        edge_attrs.append([citation_importance, time_diff])
        return {
            'edge_index': torch.tensor(edges, dtype=torch.long).t().contiguous(),
            'edge_attr': torch.tensor(edge_attrs, dtype=torch.float)
        }
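    # Sanity-check sketch (illustrative, assumes `data` was built by generate()):
    #     src, dst = data['patent', 'cites', 'patent'].edge_index
    #     assert (data['patent'].year[src] > data['patent'].year[dst]).all()
    # i.e. every citing patent was filed strictly later than the patent it cites.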
    def _generate_cooperation_edges(self, n_companies: int,
                                    company_industry: torch.Tensor,
                                    company_district: torch.Tensor) -> Dict:
        """Generate company-company cooperation edges."""
        edges = []
        edge_attrs = []
        for company_idx in range(n_companies):
            # Each company has 0-5 cooperation partners
            n_partners = np.random.poisson(2)
            n_partners = min(n_partners, 5)
            if n_partners > 0:
                # Prefer companies in the same industry (same-district peers are computed
                # as well but are currently not used for selection)
                same_industry = (company_industry == company_industry[company_idx]).nonzero(as_tuple=True)[0]
                same_district = (company_district == company_district[company_idx]).nonzero(as_tuple=True)[0]
                # Exclude the company itself
                same_industry = same_industry[same_industry != company_idx]
                same_district = same_district[same_district != company_idx]
                # 70% same industry, 30% cross-industry
                candidates = []
                if len(same_industry) > 0 and random.random() < 0.7:
                    candidates = same_industry.tolist()
                else:
                    candidates = list(range(n_companies))
                    candidates.remove(company_idx)
                if candidates:
                    partners = np.random.choice(candidates,
                                                size=min(n_partners, len(candidates)),
                                                replace=False)
                    for partner_idx in partners:
                        # Avoid duplicate (undirected) edges
                        if company_idx < partner_idx:
                            edges.append([company_idx, partner_idx])
                            # Cooperation strength
                            strength = np.random.beta(3, 3)
                            # Cooperation duration in years (normalised by a decade)
                            duration = np.random.exponential(2)
                            duration = min(duration, 10) / 10
                            edge_attrs.append([strength, duration])
        return {
            'edge_index': torch.tensor(edges, dtype=torch.long).t().contiguous(),
            'edge_attr': torch.tensor(edge_attrs, dtype=torch.float)
        }
    def _generate_collaboration_edges(self, n_companies: int, n_institutions: int,
                                      company_industry: torch.Tensor) -> torch.Tensor:
        """Generate company-institution collaboration edges."""
        edges = []
        for company_idx in range(n_companies):
            # 30% of companies collaborate with institutions
            if random.random() < 0.3:
                n_collabs = np.random.poisson(1) + 1
                n_collabs = min(n_collabs, 3)
                institutions = np.random.choice(n_institutions, size=n_collabs, replace=False)
                for inst_idx in institutions:
                    edges.append([company_idx, inst_idx])
        return torch.tensor(edges, dtype=torch.long).t().contiguous()
    def _generate_institution_patent_edges(self, n_institutions: int,
                                           n_patents: int) -> torch.Tensor:
        """Generate institution-patent production edges."""
        edges = []
        # Each institution produces multiple patents
        for inst_idx in range(n_institutions):
            n_patents_produced = np.random.poisson(20)
            n_patents_produced = min(n_patents_produced, n_patents // 2)
            if n_patents_produced > 0:
                patents = np.random.choice(n_patents, size=n_patents_produced, replace=False)
                for patent_idx in patents:
                    edges.append([inst_idx, patent_idx])
        return torch.tensor(edges, dtype=torch.long).t().contiguous()
    def _industry_distribution(self) -> np.ndarray:
        """Industry mix (roughly reflecting Hong Kong, with FinTech and AI weighted higher)."""
        weights = np.array([0.25, 0.15, 0.20, 0.10, 0.08, 0.10, 0.07, 0.05])
        return weights / weights.sum()
    def _select_district(self, industry: str) -> int:
        """Pick a district conditioned on the industry."""
        district_map = {
            '金融科技': ['中环', '湾仔', '尖沙咀'],
            '生物医药': ['科学园', '将军澳工业邨'],
            '人工智能': ['科学园', '数码港'],
            '半导体': ['科学园', '将军澳工业邨'],
            '新能源': ['科学园', '观塘'],
            '电子商务': ['数码港', '荃湾'],
            '物流科技': ['观塘', '荃湾'],
            '智能制造': ['将军澳工业邨', '观塘']
        }
        preferred = district_map.get(industry, self.districts)
        district = random.choice(preferred)
        return self.districts.index(district)
    def _print_statistics(self, data: HeteroData):
        """Print dataset statistics."""
        print("\n" + "=" * 60)
        print("📊 Dataset statistics")
        print("=" * 60)
        print("\n[Node counts]")
        for node_type in data.node_types:
            print(f"  {node_type:15s}: {data[node_type].num_nodes:6d} nodes")
        print("\n[Edge counts]")
        for edge_type in data.edge_types:
            src, rel, dst = edge_type
            num_edges = data[edge_type].num_edges
            print(f"  {src:12s} --[{rel:12s}]--> {dst:12s}: {num_edges:6d} edges")
        print("\n[Feature dimensions]")
        for node_type in data.node_types:
            print(f"  {node_type:15s}: {data[node_type].num_features} dims")
        print("\n" + "=" * 60)
    @staticmethod
    def load_data(dataset_size='medium', data_dir='data'):
        """
        Load a previously saved dataset.
        Args:
            dataset_size: dataset scale ('test', 'medium', 'large')
            data_dir: data directory path
        Returns:
            HeteroData: the loaded heterogeneous graph
        """
        filename_map = {
            'test': 'hk_ip_test.pt',
            'medium': 'hk_ip_medium.pt',
            'large': 'hk_ip_large.pt'
        }
        if dataset_size not in filename_map:
            raise ValueError(f"Unsupported dataset size: {dataset_size}. Supported options: {list(filename_map.keys())}")
        filepath = os.path.join(data_dir, filename_map[dataset_size])
        if not os.path.exists(filepath):
            print(f"❌ Data file not found: {filepath}")
            print(f"🔄 Generating the {dataset_size} dataset...")
            # Automatically generate the missing dataset
            generator = IPEcosystemGenerator(seed=42)
            size_configs = {
                'test': {
                    'n_companies': 100,
                    'n_patents': 500,
                    'n_trademarks': 200,
                    'n_persons': 300,
                    'n_institutions': 20,
                    'time_span_years': 5
                },
                'medium': {
                    'n_companies': 500,
                    'n_patents': 3000,
                    'n_trademarks': 1500,
                    'n_persons': 2000,
                    'n_institutions': 50,
                    'time_span_years': 10
                },
                'large': {
                    'n_companies': 2000,
                    'n_patents': 15000,
                    'n_trademarks': 8000,
                    'n_persons': 10000,
                    'n_institutions': 200,
                    'time_span_years': 15
                }
            }
            config = size_configs[dataset_size]
            data = generator.generate(**config)
            # Make sure the directory exists
            os.makedirs(data_dir, exist_ok=True)
            torch.save(data, filepath)
            print(f"✅ Data generated and saved to: {filepath}")
            return data
        try:
            print(f"📂 Loading dataset: {filepath}")
            data = torch.load(filepath, weights_only=False)
            print(f"✅ Loaded the {dataset_size} dataset")
            # Print a short overview
            print("\n📊 Dataset overview:")
            for node_type in data.node_types:
                print(f"  {node_type:15s}: {data[node_type].num_nodes:6d} nodes")
            total_edges = sum(data[edge_type].num_edges for edge_type in data.edge_types)
            print(f"  {'total edges':15s}: {total_edges:6d} edges")
            return data
        except Exception as e:
            print(f"❌ Failed to load data: {e}")
            print("🔄 Regenerating the dataset...")
            # Remove the corrupted file
            if os.path.exists(filepath):
                os.remove(filepath)
            # Recursive call regenerates the dataset from scratch
            return IPEcosystemGenerator.load_data(dataset_size, data_dir)
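    # Illustrative usage sketch (paths are the defaults used by this module):
    #     data = IPEcosystemGenerator.load_data('test', data_dir='data')
    #     print(data['company'].num_nodes)
    # If data/hk_ip_test.pt is missing, the call generates and saves it before returning.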
    @staticmethod
    def list_available_datasets(data_dir='data'):
        """
        List the datasets available on disk.
        Args:
            data_dir: data directory path
        """
        if not os.path.exists(data_dir):
            print(f"📁 Data directory not found: {data_dir}")
            return
        filename_map = {
            'hk_ip_test.pt': 'test (small scale)',
            'hk_ip_medium.pt': 'medium (medium scale)',
            'hk_ip_large.pt': 'large (large scale)'
        }
        print(f"\n📋 Available datasets (directory: {data_dir}):")
        print("-" * 50)
        found_any = False
        for filename, description in filename_map.items():
            filepath = os.path.join(data_dir, filename)
            if os.path.exists(filepath):
                file_size = os.path.getsize(filepath) / (1024 * 1024)  # MB
                print(f"✅ {description:20s} - {filename:20s} ({file_size:.1f} MB)")
                found_any = True
            else:
                print(f"❌ {description:20s} - {filename:20s} (missing)")
        if not found_any:
            print("🚫 No dataset files found")
            print("💡 Run data_generator.py to generate the datasets")
        print("-" * 50)
    @staticmethod
    def get_dataset_info(dataset_size='medium', data_dir='data'):
        """
        Report dataset metadata (the file is loaded to CPU but the data is not returned).
        Args:
            dataset_size: dataset scale
            data_dir: data directory path
        """
        filename_map = {
            'test': 'hk_ip_test.pt',
            'medium': 'hk_ip_medium.pt',
            'large': 'hk_ip_large.pt'
        }
        filepath = os.path.join(data_dir, filename_map[dataset_size])
        if not os.path.exists(filepath):
            print(f"❌ Data file not found: {filepath}")
            return None
        try:
            # Load the graph on CPU just to read its metadata
            data = torch.load(filepath, map_location='cpu', weights_only=False)
            info = {
                'dataset_size': dataset_size,
                'filepath': filepath,
                'file_size_mb': os.path.getsize(filepath) / (1024 * 1024),
                'node_types': data.node_types,
                'edge_types': data.edge_types,
                'node_counts': {ntype: data[ntype].num_nodes for ntype in data.node_types},
                'edge_counts': {etype: data[etype].num_edges for etype in data.edge_types},
                'total_nodes': sum(data[ntype].num_nodes for ntype in data.node_types),
                'total_edges': sum(data[etype].num_edges for etype in data.edge_types)
            }
            return info
        except Exception as e:
            print(f"❌ Failed to read dataset info: {e}")
            return None
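    # Illustrative usage sketch:
    #     info = IPEcosystemGenerator.get_dataset_info('medium', data_dir='data')
    #     if info is not None:
    #         print(info['total_nodes'], info['total_edges'])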
# ========== Usage example ==========
if __name__ == "__main__":
    # Create the generator
    generator = IPEcosystemGenerator(seed=42)
    # Make sure the data directory exists (paths are relative to the utils/ directory)
    os.makedirs('../data', exist_ok=True)
    # Check which datasets already exist
    print("🔍 Checking existing datasets...")
    IPEcosystemGenerator.list_available_datasets('../data')
    # Small test dataset
    print("\n🧪 Generating the test dataset...")
    test_data = generator.generate(
        n_companies=100,
        n_patents=500,
        n_trademarks=200,
        n_persons=300,
        n_institutions=20,
        time_span_years=5
    )
    # Medium dataset
    print("\n🏗️ Generating the medium dataset...")
    medium_data = generator.generate(
        n_companies=500,
        n_patents=3000,
        n_trademarks=1500,
        n_persons=2000,
        n_institutions=50,
        time_span_years=10
    )
    # Large dataset
    print("\n🚀 Generating the large dataset...")
    large_data = generator.generate(
        n_companies=2000,
        n_patents=15000,
        n_trademarks=8000,
        n_persons=10000,
        n_institutions=200,
        time_span_years=15
    )
    # Save the datasets
    torch.save(test_data, '../data/hk_ip_test.pt')
    torch.save(medium_data, '../data/hk_ip_medium.pt')
    torch.save(large_data, '../data/hk_ip_large.pt')
    print("\n✅ Datasets saved to the ../data/ directory")
    # Show the final state
    print("\n" + "=" * 60)
    IPEcosystemGenerator.list_available_datasets('../data')
    # Test the loading helper
    print("\n🧪 Testing data loading...")
    loaded_data = IPEcosystemGenerator.load_data('medium', data_dir='../data')
    print("✅ Data loading test passed!")