# GNN / app.py
# Huxxshadow's picture
# Rename app4.py to app.py
# 47eb12a verified
# app4.py (增强版)
import gradio as gr
import torch
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import networkx as nx
from datetime import datetime
import json
import os
from collections import defaultdict
from utils.data_generator import IPEcosystemGenerator
from model.HGT import HGT, HGTLinkPredictor, HGTNodeClassifier, HGTPatentValuePredictor, HGTCollaborationRecommender
from model.HeteroGNN import HeteroGNN, LinkPredictor, NodeClassifier
# app_enhanced.py (数据科学增强版)
import gradio as gr
import torch
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import networkx as nx
from datetime import datetime
from collections import defaultdict
# 数据科学库
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering
from sklearn.metrics import silhouette_score, davies_bouldin_score
from sklearn.preprocessing import StandardScaler
from scipy.stats import pearsonr, spearmanr
from scipy.cluster.hierarchy import dendrogram, linkage
import umap
# 原有导入
from utils.data_generator import IPEcosystemGenerator
from model.HGT import HGT, HGTLinkPredictor, HGTNodeClassifier, HGTPatentValuePredictor, HGTCollaborationRecommender
from model.HeteroGNN import HeteroGNN, LinkPredictor, NodeClassifier
# Module-level mutable state shared by all Gradio callbacks.
current_data = None  # currently loaded heterogeneous dataset (None until generated/loaded)
loaded_models = {}  # task key (e.g. 'link_prediction') -> dict of trained model components
training_history = defaultdict(list)  # task name -> recorded per-epoch training metrics
embedding_cache = {} # node type -> cached node-embedding matrix (numpy array)
# ==================== 数据科学分析模块 ====================
def compute_node_embeddings(force_recompute=False):
    """Compute node embeddings for every node type and memoize them.

    Prefers a trained encoder (link prediction, then node classification,
    then patent value, in that priority order) and falls back to the raw
    node features when no model has been trained yet.

    Args:
        force_recompute: when True, ignore the cache and recompute.

    Returns:
        (embeddings, status): dict mapping node type -> numpy array plus a
        human-readable status string, or (None, message) when no dataset
        is loaded.
    """
    global current_data, loaded_models, embedding_cache

    if current_data is None:
        return None, "❌ 请先加载数据集!"

    # Serve the memoized result unless the caller asks for a refresh.
    if embedding_cache and not force_recompute:
        return embedding_cache, "✅ 使用缓存的嵌入"

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data = current_data.to(device)

    # Pick the first available trained encoder, in priority order.
    model = None
    for task_key in ('link_prediction', 'node_classification', 'patent_value'):
        if task_key in loaded_models:
            model = loaded_models[task_key]['model']
            break

    if model is None:
        # No trained encoder yet: fall back to the raw input features.
        embedding_cache = {
            node_type: data[node_type].x.cpu().numpy()
            for node_type in data.node_types
        }
        return embedding_cache, "⚠️ 使用原始特征(建议先训练模型)"

    # Forward pass to obtain learned representations for every node type.
    model.eval()
    with torch.no_grad():
        x_dict = model(data.x_dict, data.edge_index_dict)
    embedding_cache = {k: v.cpu().numpy() for k, v in x_dict.items()}
    return embedding_cache, "✅ 使用模型嵌入"
def perform_pca_analysis(node_type, n_components=2):
    """Run PCA on the cached embeddings of one node type.

    Args:
        node_type: node type whose embeddings are analysed.
        n_components: requested output dimensionality; clamped to the
            feature count. 2 yields a 2D scatter, >=3 a 3D scatter.

    Returns:
        (report_markdown, variance_fig, projection_fig, loadings_df);
        the last three slots are None on failure.
    """
    embeddings, status = compute_node_embeddings()
    if embeddings is None:
        return status, None, None, None
    if node_type not in embeddings:
        return f"❌ 无效的节点类型: {node_type}", None, None, None
    X = embeddings[node_type]
    # BUGFIX: clamp once, up front.  The original clamped only inside the
    # PCA constructor and then indexed cumsum_var[n_components - 1] with
    # the raw argument (IndexError when n_components > feature count);
    # the same raw value also selected the 2D/3D branch, so a 3D request
    # on 2-feature data crashed indexing X_pca[:, 2].
    n_components = min(n_components, X.shape[1])
    # Standardize so no feature dominates the components by scale alone.
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)
    pca = PCA(n_components=n_components)
    X_pca = pca.fit_transform(X_scaled)
    explained_var = pca.explained_variance_ratio_
    cumsum_var = np.cumsum(explained_var)
    report = f"""
## 📊 PCA降维分析报告
**节点类型**: {node_type}
**原始维度**: {X.shape[1]}
**降维后维度**: {n_components}
**样本数量**: {X.shape[0]}
### 主成分解释方差
"""
    for i, var in enumerate(explained_var[:10]):  # show at most the first 10 PCs
        report += f"- **PC{i + 1}**: {var * 100:.2f}%\n"
    report += f"\n**前{n_components}个主成分累计解释方差**: {cumsum_var[n_components - 1] * 100:.2f}%\n"
    # Figure 1: per-component and cumulative explained variance.
    fig1 = make_subplots(
        rows=1, cols=2,
        subplot_titles=('主成分解释方差', '累计解释方差'),
        specs=[[{'type': 'bar'}, {'type': 'scatter'}]]
    )
    fig1.add_trace(
        go.Bar(
            x=[f'PC{i + 1}' for i in range(len(explained_var))],
            y=explained_var * 100,
            marker=dict(color=explained_var, colorscale='Viridis'),
            name='解释方差'
        ),
        row=1, col=1
    )
    fig1.add_trace(
        go.Scatter(
            x=[f'PC{i + 1}' for i in range(len(cumsum_var))],
            y=cumsum_var * 100,
            mode='lines+markers',
            marker=dict(size=8, color='#FF6B6B'),
            line=dict(width=3),
            name='累计方差'
        ),
        row=1, col=2
    )
    fig1.update_xaxes(title_text="主成分", row=1, col=1)
    fig1.update_xaxes(title_text="主成分", row=1, col=2)
    fig1.update_yaxes(title_text="解释方差 (%)", row=1, col=1)
    fig1.update_yaxes(title_text="累计解释方差 (%)", row=1, col=2)
    fig1.update_layout(
        title="PCA方差分析",
        template="plotly_white",
        height=400,
        showlegend=False
    )
    # Figure 2: scatter of the PCA projection (2D or 3D).
    if n_components >= 2:
        # Color by industry label for company nodes when available.
        labels = None
        label_names = None
        if node_type == 'company' and hasattr(current_data['company'], 'industry'):
            labels = current_data['company'].industry.cpu().numpy()
            # NOTE(review): assumes industry ids are 0..7 — confirm against
            # the data generator before relying on the legend names.
            label_names = ['金融科技', '生物医药', '人工智能', '半导体',
                           '新能源', '电子商务', '物流科技', '智能制造']
        if n_components == 2:
            fig2 = go.Figure()
            if labels is not None:
                for label_id in np.unique(labels):
                    mask = labels == label_id
                    fig2.add_trace(go.Scatter(
                        x=X_pca[mask, 0],
                        y=X_pca[mask, 1],
                        mode='markers',
                        name=label_names[label_id] if label_names else f'类别{label_id}',
                        marker=dict(size=6, opacity=0.7),
                        text=[f'{node_type}-{i}' for i in np.where(mask)[0]],
                        hoverinfo='text'
                    ))
            else:
                fig2.add_trace(go.Scatter(
                    x=X_pca[:, 0],
                    y=X_pca[:, 1],
                    mode='markers',
                    marker=dict(size=6, color=X_pca[:, 0], colorscale='Viridis', opacity=0.7),
                    text=[f'{node_type}-{i}' for i in range(len(X_pca))],
                    hoverinfo='text'
                ))
            fig2.update_layout(
                title=f"{node_type.upper()} - PCA 2D投影",
                xaxis_title=f"PC1 ({explained_var[0] * 100:.1f}%)",
                yaxis_title=f"PC2 ({explained_var[1] * 100:.1f}%)",
                template="plotly_white",
                height=600
            )
        else:  # 3D projection
            fig2 = go.Figure()
            if labels is not None:
                for label_id in np.unique(labels):
                    mask = labels == label_id
                    fig2.add_trace(go.Scatter3d(
                        x=X_pca[mask, 0],
                        y=X_pca[mask, 1],
                        z=X_pca[mask, 2],
                        mode='markers',
                        name=label_names[label_id] if label_names else f'类别{label_id}',
                        marker=dict(size=4, opacity=0.7),
                        text=[f'{node_type}-{i}' for i in np.where(mask)[0]],
                        hoverinfo='text'
                    ))
            else:
                fig2.add_trace(go.Scatter3d(
                    x=X_pca[:, 0],
                    y=X_pca[:, 1],
                    z=X_pca[:, 2],
                    mode='markers',
                    marker=dict(size=4, color=X_pca[:, 0], colorscale='Viridis', opacity=0.7),
                    text=[f'{node_type}-{i}' for i in range(len(X_pca))],
                    hoverinfo='text'
                ))
            fig2.update_layout(
                title=f"{node_type.upper()} - PCA 3D投影",
                scene=dict(
                    xaxis_title=f"PC1 ({explained_var[0] * 100:.1f}%)",
                    yaxis_title=f"PC2 ({explained_var[1] * 100:.1f}%)",
                    zaxis_title=f"PC3 ({explained_var[2] * 100:.1f}%)"
                ),
                template="plotly_white",
                height=600
            )
    else:
        fig2 = None
    # Loadings table: top-10 features by absolute loading on the first
    # (up to) five components.
    components_df = pd.DataFrame(
        pca.components_[:5].T,
        columns=[f'PC{i + 1}' for i in range(min(5, pca.n_components_))],
        index=[f'Feature{i + 1}' for i in range(X.shape[1])]
    )
    components_df['Abs_Max'] = components_df.abs().max(axis=1)
    components_df = components_df.sort_values('Abs_Max', ascending=False).head(10)
    components_df = components_df.drop('Abs_Max', axis=1)
    return report, fig1, fig2, components_df
def perform_tsne_analysis(node_type, n_components=2, perplexity=30, n_iter=1000):
    """Run t-SNE on the cached embeddings of one node type.

    Args:
        node_type: node type whose embeddings are analysed.
        n_components: 2 for a 2D scatter, 3 for a 3D scatter.
        perplexity: t-SNE perplexity (clamped to n_samples - 1).
        n_iter: number of optimization iterations.

    Returns:
        (report_markdown, figure, metrics_df); figure/metrics are None
        on failure, metrics_df is None when no labels are available.
    """
    embeddings, status = compute_node_embeddings()
    if embeddings is None:
        return status, None, None
    if node_type not in embeddings:
        return f"❌ 无效的节点类型: {node_type}", None, None
    X = embeddings[node_type]
    # Subsample large node sets: t-SNE is expensive in both time and memory.
    max_samples = 2000
    if len(X) > max_samples:
        indices = np.random.choice(len(X), max_samples, replace=False)
        X = X[indices]
        sampled = True
    else:
        indices = np.arange(len(X))
        sampled = False
    # Standardize before the embedding.
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)
    # NOTE(review): ``max_iter`` requires scikit-learn >= 1.5 (older
    # releases name this parameter ``n_iter``) — confirm the pinned version.
    tsne = TSNE(
        n_components=n_components,
        perplexity=min(perplexity, len(X) - 1),
        max_iter=n_iter,
        random_state=42
    )
    X_tsne = tsne.fit_transform(X_scaled)
    report = f"""
## 🎯 t-SNE降维分析报告
**节点类型**: {node_type}
**原始维度**: {X.shape[1]}
**降维后维度**: {n_components}
**样本数量**: {len(X)} {'(采样)' if sampled else ''}
**困惑度**: {perplexity}
**迭代次数**: {n_iter}
### t-SNE参数说明
- **困惑度(Perplexity)**: 平衡局部和全局结构,通常在5-50之间
- **迭代次数**: 更多迭代可能得到更好的结果,但计算时间更长
### 应用场景
t-SNE特别适合:
- 可视化高维数据的聚类结构
- 发现数据中的模式和分组
- 识别异常值和离群点
"""
    # Color by industry label for company nodes when available.
    labels = None
    label_names = None
    if node_type == 'company' and hasattr(current_data['company'], 'industry'):
        labels = current_data['company'].industry.cpu().numpy()[indices]
        # NOTE(review): assumes industry ids are 0..7 — confirm against
        # the data generator.
        label_names = ['金融科技', '生物医药', '人工智能', '半导体',
                       '新能源', '电子商务', '物流科技', '智能制造']
    if n_components == 2:
        fig = go.Figure()
        if labels is not None:
            for label_id in np.unique(labels):
                mask = labels == label_id
                fig.add_trace(go.Scatter(
                    x=X_tsne[mask, 0],
                    y=X_tsne[mask, 1],
                    mode='markers',
                    name=label_names[label_id] if label_names else f'类别{label_id}',
                    marker=dict(size=8, opacity=0.7),
                    text=[f'{node_type}-{indices[i]}' for i in np.where(mask)[0]],
                    hoverinfo='text'
                ))
        else:
            fig.add_trace(go.Scatter(
                x=X_tsne[:, 0],
                y=X_tsne[:, 1],
                mode='markers',
                marker=dict(
                    size=8,
                    color=np.arange(len(X_tsne)),
                    colorscale='Viridis',
                    opacity=0.7
                ),
                text=[f'{node_type}-{i}' for i in indices],
                hoverinfo='text'
            ))
        fig.update_layout(
            title=f"{node_type.upper()} - t-SNE 2D可视化",
            xaxis_title="t-SNE 1",
            yaxis_title="t-SNE 2",
            template="plotly_white",
            height=600
        )
    else:  # 3D
        fig = go.Figure()
        if labels is not None:
            for label_id in np.unique(labels):
                mask = labels == label_id
                fig.add_trace(go.Scatter3d(
                    x=X_tsne[mask, 0],
                    y=X_tsne[mask, 1],
                    z=X_tsne[mask, 2],
                    mode='markers',
                    name=label_names[label_id] if label_names else f'类别{label_id}',
                    marker=dict(size=5, opacity=0.7),
                    text=[f'{node_type}-{indices[i]}' for i in np.where(mask)[0]],
                    hoverinfo='text'
                ))
        else:
            fig.add_trace(go.Scatter3d(
                x=X_tsne[:, 0],
                y=X_tsne[:, 1],
                z=X_tsne[:, 2],
                mode='markers',
                marker=dict(
                    size=5,
                    color=np.arange(len(X_tsne)),
                    colorscale='Viridis',
                    opacity=0.7
                ),
                text=[f'{node_type}-{i}' for i in indices],
                hoverinfo='text'
            ))
        fig.update_layout(
            title=f"{node_type.upper()} - t-SNE 3D可视化",
            scene=dict(
                xaxis_title="t-SNE 1",
                yaxis_title="t-SNE 2",
                zaxis_title="t-SNE 3"
            ),
            template="plotly_white",
            height=600
        )
    # When labels exist, score how well the embedding separates them.
    metrics_df = None
    if labels is not None:
        # BUGFIX: the original used a bare ``except:`` which also swallows
        # SystemExit/KeyboardInterrupt; narrowed to Exception (the metrics
        # legitimately raise ValueError for a single label class).
        try:
            silhouette = silhouette_score(X_tsne, labels)
            davies_bouldin = davies_bouldin_score(X_tsne, labels)
            metrics_df = pd.DataFrame({
                '指标': ['轮廓系数', 'Davies-Bouldin指数'],
                '数值': [f'{silhouette:.4f}', f'{davies_bouldin:.4f}'],
                '说明': [
                    '[-1, 1],越接近1越好',
                    '越小越好,表示聚类紧密且分离'
                ]
            })
            report += f"\n### 聚类质量评估\n\n"
            report += f"- **轮廓系数**: {silhouette:.4f}\n"
            report += f"- **Davies-Bouldin指数**: {davies_bouldin:.4f}\n"
        except Exception:
            pass
    return report, fig, metrics_df
def perform_clustering_analysis(node_type, method='kmeans', n_clusters=5):
    """Cluster the embeddings of one node type and visualize the result.

    Args:
        node_type: node type whose embeddings are clustered.
        method: 'kmeans', 'dbscan' or 'hierarchical'.
        n_clusters: target cluster count (ignored by DBSCAN, which
            discovers the count itself; -1 labels are noise points).

    Returns:
        (report_markdown, scatter_fig, stats_df); figure/table are None
        on failure.
    """
    embeddings, status = compute_node_embeddings()
    if embeddings is None:
        # BUGFIX: the error paths used to return four values while the
        # success path returns three, breaking output unpacking downstream.
        return status, None, None
    if node_type not in embeddings:
        return f"❌ 无效的节点类型: {node_type}", None, None
    X = embeddings[node_type]
    # Standardize before clustering so no feature dominates by scale.
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)
    # 2D PCA projection used only for plotting.
    pca = PCA(n_components=2)
    X_2d = pca.fit_transform(X_scaled)
    if method == 'kmeans':
        clusterer = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
        labels = clusterer.fit_predict(X_scaled)
        centers_2d = pca.transform(clusterer.cluster_centers_)
    elif method == 'dbscan':
        clusterer = DBSCAN(eps=0.5, min_samples=5)
        labels = clusterer.fit_predict(X_scaled)
        n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
        centers_2d = None
    elif method == 'hierarchical':
        clusterer = AgglomerativeClustering(n_clusters=n_clusters)
        labels = clusterer.fit_predict(X_scaled)
        centers_2d = None
    else:
        # BUGFIX: an unknown method previously fell through and raised
        # NameError on ``labels``; fail with an explicit message instead.
        return f"❌ 未知的聚类方法: {method}", None, None
    # Quality metrics are only defined for >= 2 distinct labels.
    if len(set(labels)) > 1:
        silhouette = silhouette_score(X_scaled, labels)
        davies_bouldin = davies_bouldin_score(X_scaled, labels)
    else:
        silhouette = 0
        davies_bouldin = 0
    report = f"""
## 🎯 聚类分析报告
**节点类型**: {node_type}
**聚类方法**: {method.upper()}
**聚类数量**: {len(set(labels))}
**样本数量**: {len(X)}
### 聚类质量指标
- **轮廓系数**: {silhouette:.4f}
- **Davies-Bouldin指数**: {davies_bouldin:.4f}
### 各簇样本分布
"""
    cluster_counts = pd.Series(labels).value_counts().sort_index()
    for cluster_id, count in cluster_counts.items():
        cluster_name = f"簇 {cluster_id}" if cluster_id != -1 else "噪声点"
        report += f"- **{cluster_name}**: {count} 个样本 ({count / len(labels) * 100:.1f}%)\n"
    # Scatter of clusters in the 2D PCA projection.
    fig = go.Figure()
    for cluster_id in sorted(set(labels)):
        mask = labels == cluster_id
        cluster_name = f"簇 {cluster_id}" if cluster_id != -1 else "噪声点"
        fig.add_trace(go.Scatter(
            x=X_2d[mask, 0],
            y=X_2d[mask, 1],
            mode='markers',
            name=cluster_name,
            marker=dict(size=8, opacity=0.7),
            text=[f'{node_type}-{i}<br>簇: {cluster_id}' for i in np.where(mask)[0]],
            hoverinfo='text'
        ))
    # K-Means centers projected into the same 2D space.
    if centers_2d is not None:
        fig.add_trace(go.Scatter(
            x=centers_2d[:, 0],
            y=centers_2d[:, 1],
            mode='markers',
            name='聚类中心',
            marker=dict(
                size=15,
                symbol='x',
                color='red',
                line=dict(width=2, color='darkred')
            )
        ))
    fig.update_layout(
        title=f"{node_type.upper()} - {method.upper()} 聚类结果 (PCA投影)",
        xaxis_title=f"PC1 ({pca.explained_variance_ratio_[0] * 100:.1f}%)",
        yaxis_title=f"PC2 ({pca.explained_variance_ratio_[1] * 100:.1f}%)",
        template="plotly_white",
        height=600
    )
    # Per-cluster summary table (DBSCAN noise points are skipped).
    stats_data = []
    for cluster_id in sorted(set(labels)):
        if cluster_id == -1:
            continue
        mask = labels == cluster_id
        cluster_samples = X[mask]
        stats_data.append({
            '簇ID': cluster_id,
            '样本数': mask.sum(),
            '占比': f'{mask.sum() / len(labels) * 100:.1f}%',
            '平均值': f'{cluster_samples.mean():.4f}',
            '标准差': f'{cluster_samples.std():.4f}'
        })
    stats_df = pd.DataFrame(stats_data)
    return report, fig, stats_df
def perform_correlation_analysis(node_type):
    """Analyse pairwise feature correlations for one node type.

    Args:
        node_type: node type whose raw features are analysed.

    Returns:
        (report_markdown, heatmap_fig, high_corr_df, distribution_fig);
        on failure the last three slots are None.
    """
    if current_data is None:
        # BUGFIX: this error path used to return three values while the
        # success path returns four, breaking output unpacking downstream.
        return "❌ 请先加载数据集!", None, None, None
    X = current_data[node_type].x.cpu().numpy()
    # Pearson correlation between every pair of features.
    corr_matrix = np.corrcoef(X.T)
    feature_names = [f'F{i + 1}' for i in range(X.shape[1])]
    # Heatmap of the full correlation matrix.
    fig1 = go.Figure(data=go.Heatmap(
        z=corr_matrix,
        x=feature_names,
        y=feature_names,
        colorscale='RdBu',
        zmid=0,
        text=np.round(corr_matrix, 2),
        texttemplate='%{text}',
        textfont={"size": 8},
        colorbar=dict(title="相关系数")
    ))
    fig1.update_layout(
        title=f"{node_type.upper()} - 特征相关性热力图",
        template="plotly_white",
        height=600,
        width=700
    )
    # Collect strongly correlated feature pairs (|r| > 0.7).
    high_corr = []
    for i in range(len(corr_matrix)):
        for j in range(i + 1, len(corr_matrix)):
            if abs(corr_matrix[i, j]) > 0.7:
                high_corr.append({
                    '特征1': feature_names[i],
                    '特征2': feature_names[j],
                    '相关系数': f'{corr_matrix[i, j]:.4f}',
                    '类型': '正相关' if corr_matrix[i, j] > 0 else '负相关'
                })
    high_corr_df = pd.DataFrame(high_corr) if high_corr else pd.DataFrame({
        '提示': ['未发现高相关性特征对(|r| > 0.7)']
    })
    # Histograms of the first (up to) four features.
    fig2 = make_subplots(
        rows=2, cols=2,
        subplot_titles=[f'{feature_names[i]}分布' for i in range(min(4, len(feature_names)))]
    )
    for idx in range(min(4, X.shape[1])):
        row = idx // 2 + 1
        col = idx % 2 + 1
        fig2.add_trace(
            go.Histogram(
                x=X[:, idx],
                name=feature_names[idx],
                nbinsx=50,
                marker=dict(color=px.colors.qualitative.Set3[idx])
            ),
            row=row, col=col
        )
    fig2.update_layout(
        title=f"{node_type.upper()} - 特征分布",
        template="plotly_white",
        height=500,
        showlegend=False
    )
    report = f"""
## 📊 特征相关性分析报告
**节点类型**: {node_type}
**特征数量**: {X.shape[1]}
**样本数量**: {X.shape[0]}
### 相关性摘要
- **平均相关系数**: {np.mean(np.abs(corr_matrix[np.triu_indices_from(corr_matrix, k=1)])):.4f}
- **最大相关系数**: {np.max(corr_matrix[np.triu_indices_from(corr_matrix, k=1)]):.4f}
- **最小相关系数**: {np.min(corr_matrix[np.triu_indices_from(corr_matrix, k=1)]):.4f}
- **高相关特征对数量**: {len(high_corr)}
### 建议
"""
    if len(high_corr) > 0:
        report += "⚠️ 发现高相关性特征,可能存在冗余,建议考虑特征选择或降维。\n"
    else:
        report += "✅ 特征之间相关性较低,特征独立性良好。\n"
    return report, fig1, high_corr_df, fig2
def generate_statistics_dashboard():
    """Build a 2x2 Plotly dashboard plus a markdown report for the dataset.

    Returns:
        (report_markdown, dashboard_fig); (error_message, None) when no
        dataset is loaded.
    """
    if current_data is None:
        return "❌ 请先加载数据集!", None
    # Collect per-node-type descriptive statistics.
    stats = {}
    for node_type in current_data.node_types:
        X = current_data[node_type].x.cpu().numpy()
        stats[node_type] = {
            'count': len(X),
            'features': X.shape[1],
            'mean': X.mean(axis=0),
            'std': X.std(axis=0),
            'min': X.min(axis=0),
            'max': X.max(axis=0),
            'median': np.median(X, axis=0)
        }
    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=(
            '各节点类型数量分布',
            '平均特征维度',
            '特征均值分布',
            '特征标准差分布'
        ),
        specs=[
            [{'type': 'bar'}, {'type': 'bar'}],
            [{'type': 'box'}, {'type': 'box'}]
        ]
    )
    # Panel 1: node counts per type.
    fig.add_trace(
        go.Bar(
            x=list(stats.keys()),
            y=[s['count'] for s in stats.values()],
            marker=dict(color=px.colors.qualitative.Set2),
            text=[s['count'] for s in stats.values()],
            textposition='outside'
        ),
        row=1, col=1
    )
    # Panel 2: feature dimensionality per type.
    fig.add_trace(
        go.Bar(
            x=list(stats.keys()),
            y=[s['features'] for s in stats.values()],
            marker=dict(color=px.colors.qualitative.Set3),
            text=[s['features'] for s in stats.values()],
            textposition='outside'
        ),
        row=1, col=2
    )
    # Panel 3: distribution of per-feature means, one box per node type.
    for node_type, stat in stats.items():
        fig.add_trace(
            go.Box(y=stat['mean'], name=node_type, boxmean='sd'),
            row=2, col=1
        )
    # Panel 4: distribution of per-feature standard deviations.
    for node_type, stat in stats.items():
        fig.add_trace(
            go.Box(y=stat['std'], name=node_type, boxmean='sd'),
            row=2, col=2
        )
    fig.update_xaxes(title_text="节点类型", row=1, col=1)
    fig.update_xaxes(title_text="节点类型", row=1, col=2)
    fig.update_yaxes(title_text="数量", row=1, col=1)
    fig.update_yaxes(title_text="特征维度", row=1, col=2)
    fig.update_yaxes(title_text="特征均值", row=2, col=1)
    fig.update_yaxes(title_text="特征标准差", row=2, col=2)
    fig.update_layout(
        title="📊 数据集统计分析仪表板",
        template="plotly_white",
        height=800,
        showlegend=True
    )
    report = """
## 📊 数据集统计分析报告
"""
    for node_type, stat in stats.items():
        report += f"""
### {node_type.upper()}
- **样本数量**: {stat['count']:,}
- **特征维度**: {stat['features']}
- **特征均值范围**: [{stat['mean'].min():.4f}, {stat['mean'].max():.4f}]
- **特征标准差范围**: [{stat['std'].min():.4f}, {stat['std'].max():.4f}]
- **特征值范围**: [{stat['min'].min():.4f}, {stat['max'].max():.4f}]
"""
    total_nodes = sum(s['count'] for s in stats.values())
    total_edges = sum(current_data[et].num_edges for et in current_data.edge_types)
    # BUGFIX: the original divided by total_nodes * (total_nodes - 1) and
    # total_nodes unconditionally — ZeroDivisionError for graphs with
    # fewer than two nodes.
    density = total_edges / (total_nodes * (total_nodes - 1)) * 100 if total_nodes > 1 else 0.0
    avg_degree = total_edges * 2 / total_nodes if total_nodes > 0 else 0.0
    report += f"""
### 整体概览
- **总节点数**: {total_nodes:,}
- **总边数**: {total_edges:,}
- **网络密度**: {density:.6f}%
- **平均度**: {avg_degree:.2f}
**生成时间**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
"""
    return report, fig
# NOTE(review): these re-bindings duplicate the globals defined near the
# top of the file (and omit ``embedding_cache``).  At import time they are
# harmless because no callback has run yet, but the duplication should be
# removed — confirm nothing relies on a second reset before deleting.
current_data = None
loaded_models = {}
training_history = defaultdict(list)
# Industry label -> matching keywords (used for keyword-based tagging).
INDUSTRY_KEYWORDS = {
    '金融科技': ['金融', '支付', '区块链', '数字货币', '银行', '保险', '证券', '投资', 'fintech'],
    '生物医药': ['医药', '生物', '制药', '医疗', '健康', '诊断', '治疗', '基因', '蛋白质'],
    '人工智能': ['AI', '机器学习', '深度学习', '神经网络', '计算机视觉', 'NLP', '智能', '算法'],
    '半导体': ['芯片', '集成电路', '半导体', '晶圆', 'IC', '处理器', '传感器'],
    '新能源': ['太阳能', '风能', '电池', '储能', '充电', '新能源', '清洁能源'],
    '电子商务': ['电商', '网购', '在线', '平台', '零售', 'O2O', '跨境'],
    '物流科技': ['物流', '配送', '仓储', '供应链', '运输', '快递'],
    '智能制造': ['制造', '工业', '自动化', '机器人', '数控', '智能工厂', '工业4.0']
}
# ==================== 数据管理模块 ====================
def generate_dataset(dataset_size, n_companies, n_patents, n_trademarks, n_persons, n_institutions, time_span):
    """Generate a synthetic IP-ecosystem dataset and persist it to disk.

    Stores the result in the module-level ``current_data`` and saves it
    under ``data/custom_<dataset_size>.pt``.

    Returns:
        (status_message, stats_dataframe, details_markdown); the last two
        are None on failure.
    """
    global current_data
    try:
        generator = IPEcosystemGenerator(seed=42)
        current_data = generator.generate(
            n_companies=n_companies,
            n_patents=n_patents,
            n_trademarks=n_trademarks,
            n_persons=n_persons,
            n_institutions=n_institutions,
            time_span_years=time_span
        )
        # Persist for later reloads via load_dataset().
        os.makedirs('data', exist_ok=True)
        torch.save(current_data, f'data/custom_{dataset_size}.pt')
        return "✅ 数据集生成成功!", get_dataset_stats(), get_detailed_stats()
    except Exception as e:
        return f"❌ 生成失败: {str(e)}", None, None
def load_dataset(dataset_size):
    """Load a previously saved dataset from the ``data`` directory.

    Returns:
        (status_message, stats_dataframe, details_markdown); the last two
        are None on failure.
    """
    global current_data
    try:
        current_data = IPEcosystemGenerator.load_data(dataset_size, 'data')
        return f"✅ 成功加载 {dataset_size} 数据集", get_dataset_stats(), get_detailed_stats()
    except Exception as e:
        return f"❌ 加载失败: {str(e)}", None, None
def get_dataset_stats():
    """Tabulate node and edge counts of the current dataset.

    Returns:
        A pandas DataFrame with one row per node type and per edge type,
        or a one-row placeholder frame when no dataset is loaded.
    """
    if current_data is None:
        return pd.DataFrame({"提示": ["请先生成或加载数据集"]})
    # One row per node type.
    rows = [
        {
            "类型": "节点",
            "名称": ntype,
            "数量": current_data[ntype].num_nodes,
            "特征维度": current_data[ntype].num_features,
        }
        for ntype in current_data.node_types
    ]
    # One row per (src, relation, dst) edge type.
    for src, rel, dst in current_data.edge_types:
        rows.append({
            "类型": "关系",
            "名称": f"{src}{dst} ({rel})",
            "数量": current_data[(src, rel, dst)].num_edges,
            "特征维度": "-",
        })
    return pd.DataFrame(rows)
def get_detailed_stats():
    """Build a markdown report describing the loaded dataset in detail.

    Covers node/edge counts, per-relation breakdowns and a few density
    quality indicators.

    Returns:
        Markdown string, or a plain prompt when no dataset is loaded.
    """
    if current_data is None:
        return "请先加载数据集"
    report = """
## 📊 数据集详细统计报告
### 节点概览
"""
    total_nodes = sum(current_data[ntype].num_nodes for ntype in current_data.node_types)
    report += f"- **总节点数**: {total_nodes:,}\n\n"
    for node_type in current_data.node_types:
        num_nodes = current_data[node_type].num_nodes
        features = current_data[node_type].num_features
        percentage = (num_nodes / total_nodes) * 100
        report += f" - **{node_type}**: {num_nodes:,} 个 ({percentage:.1f}%) - {features}维特征\n"
    report += "\n### 关系概览\n"
    total_edges = sum(current_data[etype].num_edges for etype in current_data.edge_types)
    report += f"- **总关系数**: {total_edges:,}\n\n"
    # Group edge counts by relation name.
    edge_groups = {}
    for edge_type in current_data.edge_types:
        src, rel, dst = edge_type
        if rel not in edge_groups:
            edge_groups[rel] = []
        edge_groups[rel].append((src, dst, current_data[edge_type].num_edges))
    for rel, edges in edge_groups.items():
        report += f" **{rel}**:\n"
        for src, dst, count in edges:
            report += f" - {src}{dst}: {count:,} 条边\n"
        report += "\n"
    report += "### 数据质量指标\n"
    # Company-patent ownership density.
    if ('company', 'owns', 'patent') in current_data.edge_types:
        n_companies = current_data['company'].num_nodes
        n_patents = current_data['patent'].num_nodes
        n_edges = current_data['company', 'owns', 'patent'].num_edges
        # BUGFIX: guard against empty node sets (ZeroDivisionError).
        if n_companies > 0 and n_patents > 0:
            density = n_edges / (n_companies * n_patents) * 100
            avg_patents_per_company = n_edges / n_companies
            report += f"- **企业-专利密度**: {density:.4f}%\n"
            report += f"- **平均每企业专利数**: {avg_patents_per_company:.1f}\n"
    # Patent citation density.
    if ('patent', 'cites', 'patent') in current_data.edge_types:
        # BUGFIX: ``n_patents`` was previously only defined inside the
        # ownership branch above, raising NameError whenever the citation
        # relation exists without the ownership relation.
        n_patents = current_data['patent'].num_nodes
        n_citations = current_data['patent', 'cites', 'patent'].num_edges
        avg_citations = n_citations / n_patents if n_patents > 0 else 0
        report += f"- **平均每专利引用数**: {avg_citations:.1f}\n"
    report += f"\n**生成时间**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
    return report
def visualize_network_overview():
    """Bar chart of node counts per node type.

    Returns:
        A Plotly figure, or None when no dataset is loaded.
    """
    if current_data is None:
        return None
    node_counts = {ntype: current_data[ntype].num_nodes for ntype in current_data.node_types}
    type_names = list(node_counts.keys())
    counts = list(node_counts.values())
    bar = go.Bar(
        x=type_names,
        y=counts,
        marker=dict(
            color=['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A', '#98D8C8'],
            line=dict(color='white', width=2)
        ),
        text=counts,
        textposition='outside',
    )
    fig = go.Figure(data=[bar])
    fig.update_layout(
        title="📊 节点类型分布统计",
        xaxis_title="节点类型",
        yaxis_title="数量",
        template="plotly_white",
        height=400,
        font=dict(size=14)
    )
    return fig
def visualize_edge_distribution():
    """Donut chart showing how edges split across relation types.

    Returns:
        A Plotly figure, or None when no dataset is loaded.
    """
    if current_data is None:
        return None
    # Keyed by "src→dst"; if two relations share endpoints the later
    # entry wins, matching a plain dict assignment.
    edge_counts = {
        f"{src}{dst}": current_data[(src, rel, dst)].num_edges
        for src, rel, dst in current_data.edge_types
    }
    pie = go.Pie(
        labels=list(edge_counts.keys()),
        values=list(edge_counts.values()),
        hole=0.3,
        marker=dict(colors=px.colors.qualitative.Set3),
        textinfo='label+percent',
        textfont=dict(size=12)
    )
    fig = go.Figure(data=[pie])
    fig.update_layout(
        title="🔗 关系类型分布",
        template="plotly_white",
        height=500,
        font=dict(size=13)
    )
    return fig
def visualize_network_graph(node_limit=100):
    """Render a company-patent subgraph with a NetworkX layout and Plotly.

    Only the first ``node_limit`` companies/patents and at most 500
    ownership edges are drawn, to keep the spring layout tractable.

    Args:
        node_limit: cap on companies and on patents included in the plot.

    Returns:
        A Plotly figure, or None when no dataset is loaded.
    """
    if current_data is None:
        return None
    G = nx.Graph()
    # Company nodes (capped for performance).
    n_companies = min(node_limit, current_data['company'].num_nodes)
    for i in range(n_companies):
        G.add_node(f"C{i}", node_type='company', size=10, color='#FF6B6B')
    # Patent nodes (capped for performance).
    n_patents = min(node_limit, current_data['patent'].num_nodes)
    for i in range(n_patents):
        G.add_node(f"P{i}", node_type='patent', size=5, color='#4ECDC4')
    # Ownership edges whose endpoints survived the caps.
    edge_index = current_data['company', 'owns', 'patent'].edge_index
    for i in range(min(500, edge_index.size(1))):
        company_idx = edge_index[0, i].item()
        patent_idx = edge_index[1, i].item()
        if company_idx < n_companies and patent_idx < n_patents:
            G.add_edge(f"C{company_idx}", f"P{patent_idx}")
    pos = nx.spring_layout(G, k=0.5, iterations=50)
    # PERF: accumulate coordinates in plain lists and build each trace
    # once.  The original appended tuples directly onto trace properties
    # (``trace['x'] += tuple(...)``), which re-validates and copies the
    # whole array on every edge/node — quadratic in graph size.
    edge_x, edge_y = [], []
    for u, v in G.edges():
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        edge_x.extend((x0, x1, None))  # None breaks the line between edges
        edge_y.extend((y0, y1, None))
    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines'
    )
    company_x, company_y, company_text = [], [], []
    patent_x, patent_y, patent_text = [], [], []
    for node in G.nodes():
        x, y = pos[node]
        if node.startswith('C'):
            company_x.append(x)
            company_y.append(y)
            company_text.append(f'企业 {node}')
        else:
            patent_x.append(x)
            patent_y.append(y)
            patent_text.append(f'专利 {node}')
    company_trace = go.Scatter(
        x=company_x, y=company_y,
        mode='markers',
        name='企业',
        marker=dict(size=10, color='#FF6B6B', line=dict(width=2, color='white')),
        text=company_text,
        hoverinfo='text'
    )
    patent_trace = go.Scatter(
        x=patent_x, y=patent_y,
        mode='markers',
        name='专利',
        marker=dict(size=6, color='#4ECDC4', line=dict(width=1, color='white')),
        text=patent_text,
        hoverinfo='text'
    )
    fig = go.Figure(data=[edge_trace, company_trace, patent_trace])
    fig.update_layout(
        title=f"🌐 知识产权生态网络图谱 (显示前{node_limit}个节点)",
        showlegend=True,
        hovermode='closest',
        template='plotly_white',
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        height=600,
        font=dict(size=12)
    )
    return fig
# ==================== 模型训练模块 ====================
def train_model(task_type, model_type, epochs, hidden_channels, learning_rate, n_heads, num_layers):
    """Unified training entry point, dispatched on the (Chinese) task name.

    Trains the requested task, stores the resulting components in the
    module-level ``loaded_models`` registry and ``training_history``, and
    returns a status message, a training-curve figure and a report.

    Returns:
        (status_message, curves_figure, report_markdown); the last two
        are None on failure.
    """
    global current_data, loaded_models, training_history
    if current_data is None:
        return "❌ 请先加载数据集!", None, None
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        # Reset any previous history for this task before retraining.
        training_history[task_type] = []
        if task_type == "链接预测":  # link prediction
            model, predictor, history = train_link_prediction_task(
                current_data, model_type, epochs, hidden_channels,
                learning_rate, n_heads, num_layers, device
            )
            loaded_models['link_prediction'] = {'model': model, 'predictor': predictor}
        elif task_type == "节点分类":  # node classification
            model, classifier, history = train_node_classification_task(
                current_data, model_type, epochs, hidden_channels,
                learning_rate, n_heads, num_layers, device
            )
            loaded_models['node_classification'] = {'model': model, 'classifier': classifier}
        elif task_type == "专利价值评估":  # patent value regression
            model, value_predictor, history = train_patent_value_task(
                current_data, epochs, hidden_channels,
                learning_rate, n_heads, num_layers, device
            )
            loaded_models['patent_value'] = {'model': model, 'predictor': value_predictor}
        elif task_type == "企业合作推荐":  # collaboration recommendation
            model, collab_recommender, history = train_collaboration_task(
                current_data, epochs, hidden_channels,
                learning_rate, n_heads, num_layers, device
            )
            loaded_models['collaboration'] = {'model': model, 'recommender': collab_recommender}
        # NOTE(review): an unrecognized ``task_type`` falls through with
        # ``history`` unbound and raises NameError below (caught by the
        # broad except) — confirm the UI can only supply the four labels.
        training_history[task_type] = history
        fig = plot_training_curves(history, task_type)
        report = generate_training_report(history, task_type)
        return f"✅ {task_type}模型训练完成!", fig, report
    except Exception as e:
        return f"❌ 训练失败: {str(e)}", None, None
def train_link_prediction_task(data, model_type, epochs, hidden_channels, lr, n_heads, num_layers, device):
    """Train company->patent ownership link prediction.

    Splits the ownership edges 80/20 into train/val, trains the chosen
    encoder (HGT or HeteroGNN) plus a link-prediction head with BCE over
    positive edges and an equal number of sampled negatives, and records
    validation AUC/AP every 5 epochs.

    Returns:
        (model, predictor, history) where ``history`` holds the per-eval
        'epoch', 'train_loss', 'val_auc' and 'val_ap' lists.
    """
    edge_type = ('company', 'owns', 'patent')
    edge_index = data[edge_type].edge_index
    # Random 80/20 edge split for training/validation.
    num_edges = edge_index.size(1)
    perm = torch.randperm(num_edges)
    train_size = int(0.8 * num_edges)
    train_edge_index = edge_index[:, perm[:train_size]]
    val_edge_index = edge_index[:, perm[train_size:]]
    # Encoder + link-prediction head.
    if model_type == "HGT":
        model = HGT(
            hidden_channels=hidden_channels,
            out_channels=hidden_channels,
            num_layers=num_layers,
            n_heads=n_heads,
            dropout=0.2,
            metadata=data.metadata()
        ).to(device)
        predictor = HGTLinkPredictor(hidden_channels, hidden_channels // 2, 2).to(device)
    else:
        model = HeteroGNN(hidden_channels, num_layers, data.metadata()).to(device)
        predictor = LinkPredictor(hidden_channels).to(device)
    optimizer = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()), lr=lr)
    data = data.to(device)
    src_type, _, dst_type = edge_type
    history = {'epoch': [], 'train_loss': [], 'val_auc': [], 'val_ap': []}
    for epoch in range(epochs):
        model.train()
        predictor.train()
        optimizer.zero_grad()
        x_dict = model(data.x_dict, data.edge_index_dict)
        pos_pred = predictor(x_dict[src_type], x_dict[dst_type], train_edge_index)
        # Sample as many negative pairs as there are positive edges.
        # NOTE(review): ``negative_sampling`` is not imported anywhere
        # visible in this file — presumably
        # ``torch_geometric.utils.negative_sampling``; confirm the import
        # exists elsewhere, otherwise this raises NameError at runtime.
        neg_edge_index = negative_sampling(train_edge_index, data[src_type].num_nodes,
                                           data[dst_type].num_nodes, train_edge_index.size(1)).to(device)
        neg_pred = predictor(x_dict[src_type], x_dict[dst_type], neg_edge_index)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(
            torch.cat([pos_pred, neg_pred]),
            torch.cat([torch.ones_like(pos_pred), torch.zeros_like(neg_pred)])
        )
        loss.backward()
        optimizer.step()
        # Evaluate (and record metrics) every 5 epochs.
        if epoch % 5 == 0:
            model.eval()
            predictor.eval()
            with torch.no_grad():
                x_dict = model(data.x_dict, data.edge_index_dict)
                val_pos_pred = predictor(x_dict[src_type], x_dict[dst_type], val_edge_index)
                val_neg_edge_index = negative_sampling(val_edge_index, data[src_type].num_nodes,
                                                       data[dst_type].num_nodes, val_edge_index.size(1)).to(device)
                val_neg_pred = predictor(x_dict[src_type], x_dict[dst_type], val_neg_edge_index)
                preds = torch.cat([val_pos_pred, val_neg_pred]).sigmoid().cpu().numpy()
                labels = np.concatenate([np.ones(val_pos_pred.size(0)), np.zeros(val_neg_pred.size(0))])
                from sklearn.metrics import roc_auc_score, average_precision_score
                val_auc = roc_auc_score(labels, preds)
                val_ap = average_precision_score(labels, preds)
            history['epoch'].append(epoch)
            history['train_loss'].append(loss.item())
            history['val_auc'].append(val_auc)
            history['val_ap'].append(val_ap)
    return model, predictor, history
def train_node_classification_task(data, model_type, epochs, hidden_channels, lr, n_heads, num_layers, device):
    """Train an industry classifier for company nodes.

    Splits companies 60/20 into train/val masks (the remaining 20% is
    unused here), trains the chosen encoder (HGT or HeteroGNN) plus a
    classification head with cross-entropy, and records accuracy every
    5 epochs.

    Returns:
        (model, classifier, history) where ``history`` holds the per-eval
        'epoch', 'train_loss', 'train_acc' and 'val_acc' lists.
    """
    node_type = 'company'
    target = 'industry'
    labels = data[node_type][target]
    # Class count inferred from the maximum label id.
    num_classes = labels.max().item() + 1
    num_nodes = data[node_type].num_nodes
    # Random 60/20 train/val split.
    perm = torch.randperm(num_nodes)
    train_size = int(0.6 * num_nodes)
    val_size = int(0.2 * num_nodes)
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    train_mask[perm[:train_size]] = True
    val_mask[perm[train_size:train_size + val_size]] = True
    # Encoder + classification head.
    if model_type == "HGT":
        model = HGT(
            hidden_channels=hidden_channels,
            out_channels=hidden_channels,
            num_layers=num_layers,
            n_heads=n_heads,
            dropout=0.2,
            metadata=data.metadata()
        ).to(device)
        classifier = HGTNodeClassifier(hidden_channels, num_classes, hidden_channels // 2, 2).to(device)
    else:
        model = HeteroGNN(hidden_channels, num_layers, data.metadata()).to(device)
        classifier = NodeClassifier(hidden_channels, num_classes).to(device)
    optimizer = torch.optim.Adam(list(model.parameters()) + list(classifier.parameters()), lr=lr)
    data = data.to(device)
    labels = labels.to(device)
    train_mask = train_mask.to(device)
    val_mask = val_mask.to(device)
    history = {'epoch': [], 'train_loss': [], 'train_acc': [], 'val_acc': []}
    for epoch in range(epochs):
        model.train()
        classifier.train()
        optimizer.zero_grad()
        x_dict = model(data.x_dict, data.edge_index_dict)
        out = classifier(x_dict[node_type])
        loss = torch.nn.functional.cross_entropy(out[train_mask], labels[train_mask])
        loss.backward()
        optimizer.step()
        # Evaluate (and record metrics) every 5 epochs.
        if epoch % 5 == 0:
            model.eval()
            classifier.eval()
            with torch.no_grad():
                x_dict = model(data.x_dict, data.edge_index_dict)
                out = classifier(x_dict[node_type])
                pred = out.argmax(dim=1)
                train_acc = (pred[train_mask] == labels[train_mask]).float().mean().item()
                val_acc = (pred[val_mask] == labels[val_mask]).float().mean().item()
            history['epoch'].append(epoch)
            history['train_loss'].append(loss.item())
            history['train_acc'].append(train_acc)
            history['val_acc'].append(val_acc)
    return model, classifier, history
def train_patent_value_task(data, epochs, hidden_channels, lr, n_heads, num_layers, device):
    """Train an HGT-based patent value regressor.

    The regression target is SYNTHETIC: a weighted combination of patent
    feature columns 1/3/4/6 plus uniform noise, clamped to [0, 100] — there
    is no ground-truth value label in the dataset.

    Returns:
        (model, value_predictor, history) — history has 'epoch', 'train_loss',
        'val_mae', 'val_rmse' lists, recorded every 5 epochs.
    """
    data = data.to(device)
    patent_features = data['patent'].x
    # Synthetic value label from weighted feature columns + noise.
    # Columns presumably: 1=claims, 3=citations, 4=tech breadth, 6=grant flag
    # — matches the display code elsewhere in this file; TODO confirm.
    value_labels = (patent_features[:, 1] * 30 + patent_features[:, 3] * 25 +
                    patent_features[:, 4] * 20 + patent_features[:, 6] * 15 +
                    torch.rand(len(patent_features), device=device) * 10)
    value_labels = torch.clamp(value_labels, 0, 100)
    num_patents = data['patent'].num_nodes
    # Random 60/20 train/val split over patent nodes.
    perm = torch.randperm(num_patents)
    train_size = int(0.6 * num_patents)
    val_size = int(0.2 * num_patents)
    train_mask = torch.zeros(num_patents, dtype=torch.bool)
    val_mask = torch.zeros(num_patents, dtype=torch.bool)
    train_mask[perm[:train_size]] = True
    val_mask[perm[train_size:train_size + val_size]] = True
    model = HGT(
        hidden_channels=hidden_channels,
        out_channels=hidden_channels,
        num_layers=num_layers,
        n_heads=n_heads,
        dropout=0.2,
        metadata=data.metadata()
    ).to(device)
    value_predictor = HGTPatentValuePredictor(hidden_channels, hidden_channels, 2).to(device)
    # weight_decay adds L2 regularization; only this task uses it.
    optimizer = torch.optim.Adam(
        list(model.parameters()) + list(value_predictor.parameters()),
        lr=lr,
        weight_decay=1e-4
    )
    train_mask = train_mask.to(device)
    val_mask = val_mask.to(device)
    history = {'epoch': [], 'train_loss': [], 'val_mae': [], 'val_rmse': []}
    for epoch in range(epochs):
        model.train()
        value_predictor.train()
        optimizer.zero_grad()
        x_dict = model(data.x_dict, data.edge_index_dict)
        pred_values = value_predictor(x_dict['patent']).squeeze()
        loss = torch.nn.functional.mse_loss(pred_values[train_mask], value_labels[train_mask])
        loss.backward()
        # Gradient clipping keeps the MSE regression stable early in training.
        torch.nn.utils.clip_grad_norm_(
            list(model.parameters()) + list(value_predictor.parameters()), 1.0
        )
        optimizer.step()
        # Periodic validation: MAE and RMSE on the held-out patents.
        if epoch % 5 == 0:
            model.eval()
            value_predictor.eval()
            with torch.no_grad():
                x_dict = model(data.x_dict, data.edge_index_dict)
                pred_values = value_predictor(x_dict['patent']).squeeze()
                val_mae = torch.nn.functional.l1_loss(
                    pred_values[val_mask], value_labels[val_mask]
                ).item()
                val_rmse = torch.sqrt(
                    torch.nn.functional.mse_loss(pred_values[val_mask], value_labels[val_mask])
                ).item()
                history['epoch'].append(epoch)
                history['train_loss'].append(loss.item())
                history['val_mae'].append(val_mae)
                history['val_rmse'].append(val_rmse)
    return model, value_predictor, history
def train_collaboration_task(data, epochs, hidden_channels, lr, n_heads, num_layers, device):
    """Train an HGT-based company-collaboration recommender.

    Treats existing ('company','cooperates','company') edges as positives and
    draws an equal number of rejection-sampled negatives, then trains a binary
    classifier on the pair's success probability.

    Returns:
        (model, collab_recommender, history) — history has 'epoch',
        'train_loss', 'val_auc', 'val_ap' lists, recorded every 5 epochs.
    """
    data = data.to(device)
    existing_edges = data['company', 'cooperates', 'company'].edge_index
    num_companies = data['company'].num_nodes
    # pos_edges: [E, 2] row-per-edge layout, matching negative_sampling_collab.
    pos_edges = existing_edges.t()
    neg_edges = negative_sampling_collab(pos_edges, num_companies, len(pos_edges))
    all_edges = torch.cat([pos_edges, neg_edges], dim=0)
    labels = torch.cat([torch.ones(len(pos_edges)), torch.zeros(len(neg_edges))])
    # Random 60/20 train/val split over the combined pos+neg edge set.
    num_edges = len(all_edges)
    perm = torch.randperm(num_edges)
    train_size = int(0.6 * num_edges)
    val_size = int(0.2 * num_edges)
    train_edges = all_edges[perm[:train_size]].to(device)
    val_edges = all_edges[perm[train_size:train_size + val_size]].to(device)
    train_labels = labels[perm[:train_size]].to(device)
    val_labels = labels[perm[train_size:train_size + val_size]].to(device)
    model = HGT(
        hidden_channels=hidden_channels,
        out_channels=hidden_channels,
        num_layers=num_layers,
        n_heads=n_heads,
        dropout=0.2,
        metadata=data.metadata()
    ).to(device)
    collab_recommender = HGTCollaborationRecommender(hidden_channels, hidden_channels, 2).to(device)
    optimizer = torch.optim.Adam(list(model.parameters()) + list(collab_recommender.parameters()), lr=lr)
    history = {'epoch': [], 'train_loss': [], 'val_auc': [], 'val_ap': []}
    for epoch in range(epochs):
        model.train()
        collab_recommender.train()
        optimizer.zero_grad()
        x_dict = model(data.x_dict, data.edge_index_dict)
        # Recommender expects edges in [2, E] layout, hence the transpose.
        results = collab_recommender(x_dict['company'], x_dict['company'], train_edges.t())
        # Assumes 'success_probability' is already in (0, 1) (e.g. sigmoid output)
        # — required by binary_cross_entropy; TODO confirm in the model.
        loss = torch.nn.functional.binary_cross_entropy(results['success_probability'], train_labels)
        loss.backward()
        optimizer.step()
        if epoch % 5 == 0:
            model.eval()
            collab_recommender.eval()
            with torch.no_grad():
                x_dict = model(data.x_dict, data.edge_index_dict)
                val_results = collab_recommender(x_dict['company'], x_dict['company'], val_edges.t())
                val_pred = val_results['success_probability']
                # Local import keeps sklearn optional until evaluation time.
                from sklearn.metrics import roc_auc_score, average_precision_score
                val_auc = roc_auc_score(val_labels.cpu().numpy(), val_pred.cpu().numpy())
                val_ap = average_precision_score(val_labels.cpu().numpy(), val_pred.cpu().numpy())
                history['epoch'].append(epoch)
                history['train_loss'].append(loss.item())
                history['val_auc'].append(val_auc)
                history['val_ap'].append(val_ap)
    return model, collab_recommender, history
def negative_sampling(edge_index, num_nodes_src, num_nodes_dst, num_neg):
    """Sample `num_neg` negative (src, dst) pairs for bipartite link prediction.

    Fixes over the previous version:
    - `edge_index` was accepted but never used, so "negatives" could coincide
      with existing positive edges; sampled pairs are now rejected if they
      appear in `edge_index`.
    - The old while-loop and its break condition were dead code (always a
      single iteration); the loop now actually resamples until `num_neg`
      valid pairs are collected.
    - `num_neg == 0` no longer crashes on `torch.cat([])`; an empty [2, 0]
      tensor is returned.

    Args:
        edge_index: [2, E] LongTensor of existing positive edges
            (src row 0, dst row 1) — presumably PyG convention; confirm at callers.
        num_nodes_src: number of source-side nodes (src sampled in [0, num_nodes_src)).
        num_nodes_dst: number of destination-side nodes.
        num_neg: number of negative pairs to draw.

    Returns:
        [2, num_neg] LongTensor on CPU (same placement as the old version).

    NOTE(review): rejection sampling can loop for a long time if nearly all
    possible pairs are positive edges; duplicates among the sampled negatives
    are allowed (matching the old behavior).
    """
    # Set of existing edges for O(1) rejection checks.
    existing = set(map(tuple, edge_index.t().tolist())) if edge_index.numel() else set()
    src_out, dst_out = [], []
    while len(src_out) < num_neg:
        remaining = num_neg - len(src_out)
        # Oversample 2x per round to reduce the number of loop iterations.
        src = torch.randint(0, num_nodes_src, (remaining * 2,))
        dst = torch.randint(0, num_nodes_dst, (remaining * 2,))
        for s, d in zip(src.tolist(), dst.tolist()):
            if (s, d) not in existing:
                src_out.append(s)
                dst_out.append(d)
                if len(src_out) == num_neg:
                    break
    return torch.tensor([src_out, dst_out], dtype=torch.long)
def negative_sampling_collab(pos_edges, num_nodes, num_neg):
    """Sample `num_neg` company pairs that are NOT existing cooperation edges.

    Args:
        pos_edges: [E, 2] LongTensor of existing (src, dst) cooperation pairs.
        num_nodes: number of company nodes (both endpoints sampled in [0, num_nodes)).
        num_neg: number of negative pairs to collect.

    Returns:
        [num_neg, 2] LongTensor on the same device as `pos_edges`.

    Rejection-samples candidates, skipping self-loops and pairs that exist in
    either direction (cooperation is treated as undirected here).
    NOTE(review): can loop for a long time on a near-complete graph, and the
    returned negatives may contain duplicate pairs — confirm both are
    acceptable for training.
    """
    device = pos_edges.device
    neg_edges = []
    # Existing edges as a set of (src, dst) tuples for O(1) membership tests.
    existing_set = set(map(tuple, pos_edges.tolist()))
    while len(neg_edges) < num_neg:
        # Oversample 2x per round, then filter.
        src = torch.randint(0, num_nodes, (num_neg * 2,), device=device)
        dst = torch.randint(0, num_nodes, (num_neg * 2,), device=device)
        mask = src != dst  # drop self-loops
        candidates = torch.stack([src[mask], dst[mask]], dim=1)
        for edge in candidates:
            edge_tuple = tuple(edge.tolist())
            reverse_tuple = tuple(edge.flip(0).tolist())
            # Reject the pair if it exists in either direction.
            if edge_tuple not in existing_set and reverse_tuple not in existing_set:
                neg_edges.append(edge)
                if len(neg_edges) >= num_neg:
                    break
        if len(neg_edges) >= num_neg:
            break
    return torch.stack(neg_edges[:num_neg])
def plot_training_curves(history, task_type):
    """Render a two-panel Plotly figure for a training run.

    Left panel: training loss. Right panel: whichever validation metrics the
    history dict carries (AUC/AP for link tasks, accuracy for classification,
    MAE/RMSE for regression).
    """
    fig = make_subplots(rows=1, cols=2, subplot_titles=('训练损失', '验证指标'))

    def _add_curve(values, label, color, col):
        # All curves share the same line+marker styling; only data/color differ.
        fig.add_trace(
            go.Scatter(x=history['epoch'], y=values,
                       mode='lines+markers', name=label,
                       line=dict(color=color, width=2)),
            row=1, col=col
        )

    _add_curve(history['train_loss'], '训练损失', '#FF6B6B', 1)

    # Pick the validation panel contents based on which metrics were recorded.
    if 'val_auc' in history:
        _add_curve(history['val_auc'], '验证AUC', '#4ECDC4', 2)
        if 'val_ap' in history:
            _add_curve(history['val_ap'], '验证AP', '#95E1D3', 2)
    elif 'val_acc' in history:
        _add_curve(history['train_acc'], '训练准确率', '#4ECDC4', 2)
        _add_curve(history['val_acc'], '验证准确率', '#95E1D3', 2)
    elif 'val_mae' in history:
        _add_curve(history['val_mae'], '验证MAE', '#4ECDC4', 2)
        _add_curve(history['val_rmse'], '验证RMSE', '#95E1D3', 2)

    for col in (1, 2):
        fig.update_xaxes(title_text="训练轮次", row=1, col=col)
    fig.update_yaxes(title_text="损失值", row=1, col=1)
    fig.update_layout(
        title=f"📈 {task_type} - 训练过程监控",
        template="plotly_white",
        height=400,
        showlegend=True,
        font=dict(size=12)
    )
    return fig
def generate_training_report(history, task_type):
    """Build a Markdown summary of a finished training run.

    Returns the Chinese placeholder string when history is empty/invalid;
    otherwise a report whose metric section depends on which validation
    metrics the history dict recorded.
    """
    if not history or 'epoch' not in history:
        return "训练历史为空"
    report = f"""
## 📋 {task_type} 训练报告
**训练配置**
- 总训练轮次: {history['epoch'][-1] + 1}
- 最终训练损失: {history['train_loss'][-1]:.4f}
**性能指标**
"""
    # Collect the metric lines first, then append them in one go.
    metric_lines = []
    if 'val_auc' in history:
        metric_lines.append(f"- 最佳验证AUC: {max(history['val_auc']):.4f}\n")
        metric_lines.append(f"- 最终验证AUC: {history['val_auc'][-1]:.4f}\n")
        if 'val_ap' in history:
            metric_lines.append(f"- 最佳验证AP: {max(history['val_ap']):.4f}\n")
    elif 'val_acc' in history:
        metric_lines.append(f"- 最佳验证准确率: {max(history['val_acc']):.4f}\n")
        metric_lines.append(f"- 最终验证准确率: {history['val_acc'][-1]:.4f}\n")
    elif 'val_mae' in history:
        metric_lines.append(f"- 最佳验证MAE: {min(history['val_mae']):.4f}\n")
        metric_lines.append(f"- 最终验证MAE: {history['val_mae'][-1]:.4f}\n")
    report += "".join(metric_lines)
    report += f"\n**训练完成时间**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
    return report
# ==================== 推理预测模块 ====================
def predict_link(company_id, num_predictions):
"""预测企业-专利链接"""
global current_data, loaded_models
if current_data is None or 'link_prediction' not in loaded_models:
return "❌ 请先加载数据并训练链接预测模型!", None
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = current_data.to(device)
model = loaded_models['link_prediction']['model']
predictor = loaded_models['link_prediction']['predictor']
model.eval()
predictor.eval()
with torch.no_grad():
x_dict = model(data.x_dict, data.edge_index_dict)
num_patents = data['patent'].num_nodes
company_idx = min(company_id, data['company'].num_nodes - 1)
edge_candidates = torch.stack([
torch.full((num_patents,), company_idx, dtype=torch.long),
torch.arange(num_patents)
]).to(device)
scores = predictor(
x_dict['company'],
x_dict['patent'],
edge_candidates
).sigmoid().cpu().numpy()
top_indices = np.argsort(scores)[-num_predictions:][::-1]
results = []
for idx in top_indices:
results.append({
"专利ID": f"P-{idx}",
"预测得分": f"{scores[idx]:.4f}",
"置信度": f"{scores[idx] * 100:.2f}%"
})
df = pd.DataFrame(results)
fig = go.Figure(data=[
go.Bar(x=[r["专利ID"] for r in results],
y=[float(r["预测得分"]) for r in results],
marker=dict(color=[float(r["预测得分"]) for r in results],
colorscale='Viridis'))
])
fig.update_layout(
title=f"企业 C-{company_idx} 的专利关联预测 (Top-{num_predictions})",
xaxis_title="专利ID",
yaxis_title="预测得分",
template="plotly_white",
height=400
)
return df, fig
# ==================== 新功能1: 专利价值排行榜 ====================
def get_patent_value_leaderboard(top_n=50):
"""获取专利价值排行榜"""
global current_data, loaded_models
if current_data is None or 'patent_value' not in loaded_models:
return "❌ 请先加载数据并训练专利价值评估模型!", None, None
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = current_data.to(device)
model = loaded_models['patent_value']['model']
value_predictor = loaded_models['patent_value']['predictor']
model.eval()
value_predictor.eval()
with torch.no_grad():
x_dict = model(data.x_dict, data.edge_index_dict)
values, grant_probs, renewal_probs = value_predictor(x_dict['patent'], return_aux=True)
values = values.squeeze().cpu().numpy()
grant_probs = grant_probs.squeeze().cpu().numpy()
renewal_probs = renewal_probs.squeeze().cpu().numpy()
# 排序
top_indices = np.argsort(values)[-top_n:][::-1]
# 构建排行榜
leaderboard = []
for rank, idx in enumerate(top_indices, 1):
patent_features = data['patent'].x[idx].cpu().numpy()
leaderboard.append({
"排名": rank,
"专利ID": f"P-{idx}",
"价值评分": f"{values[idx]:.2f}",
"授权概率": f"{grant_probs[idx]*100:.1f}%",
"续费概率": f"{renewal_probs[idx]*100:.1f}%",
"权利要求数": int(patent_features[1] * 30),
"引用数": int(patent_features[3] * 50),
"技术宽度": int(patent_features[4] * 5)
})
df = pd.DataFrame(leaderboard)
# 可视化 - Top 20
top_20 = leaderboard[:20]
fig = make_subplots(
rows=1, cols=2,
subplot_titles=('Top 20 专利价值分布', '价值与授权概率关系'),
specs=[[{'type': 'bar'}, {'type': 'scatter'}]]
)
# 柱状图
fig.add_trace(
go.Bar(
x=[item["专利ID"] for item in top_20],
y=[float(item["价值评分"]) for item in top_20],
marker=dict(
color=[float(item["价值评分"]) for item in top_20],
colorscale='Reds',
showscale=True
),
name='价值评分'
),
row=1, col=1
)
# 散点图
fig.add_trace(
go.Scatter(
x=[float(item["价值评分"]) for item in leaderboard],
y=[float(item["授权概率"].rstrip('%')) for item in leaderboard],
mode='markers',
marker=dict(
size=8,
color=[float(item["价值评分"]) for item in leaderboard],
colorscale='Viridis',
showscale=True
),
text=[item["专利ID"] for item in leaderboard],
name='专利分布'
),
row=1, col=2
)
fig.update_xaxes(title_text="专利ID", tickangle=45, row=1, col=1)
fig.update_xaxes(title_text="价值评分", row=1, col=2)
fig.update_yaxes(title_text="价值评分", row=1, col=1)
fig.update_yaxes(title_text="授权概率 (%)", row=1, col=2)
fig.update_layout(
title=f"💎 专利价值排行榜 (Top {top_n})",
template="plotly_white",
height=500,
showlegend=False
)
# 生成报告
avg_value = np.mean(values)
top10_avg = np.mean([float(item["价值评分"]) for item in leaderboard[:10]])
report = f"""
## 📊 专利价值排行榜分析
**整体概况**
- 总专利数量: {len(values):,}
- 平均价值评分: {avg_value:.2f}
- Top 10 平均评分: {top10_avg:.2f}
- 最高价值: {values[top_indices[0]]:.2f} (P-{top_indices[0]})
**价值分布**
- 高价值专利 (≥80分): {np.sum(values >= 80)} 个 ({np.sum(values >= 80)/len(values)*100:.1f}%)
- 中等价值 (60-80分): {np.sum((values >= 60) & (values < 80))}
- 一般价值 (<60分): {np.sum(values < 60)}
**生成时间**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
"""
return report, df, fig
# ==================== 新功能2: 基于描述的产业分类与专利推荐 ====================
def classify_and_recommend_by_description(description, num_patents=10):
"""根据描述分类产业并推荐专利"""
global current_data, loaded_models
if current_data is None or 'node_classification' not in loaded_models:
return "❌ 请先加载数据并训练节点分类模型!", None, None, None
# 基于关键词匹配推断产业
description_lower = description.lower()
industry_scores = {}
for industry, keywords in INDUSTRY_KEYWORDS.items():
score = sum(1 for keyword in keywords if keyword.lower() in description_lower)
industry_scores[industry] = score
# 获取最可能的产业
if max(industry_scores.values()) == 0:
return "❌ 无法从描述中识别产业类型,请提供更多产业相关关键词", None, None, None
predicted_industry = max(industry_scores, key=industry_scores.get)
industries_list = ['金融科技', '生物医药', '人工智能', '半导体', '新能源', '电子商务', '物流科技', '智能制造']
industry_idx = industries_list.index(predicted_industry)
# 生成分类报告
classification_report = f"""
## 🏷️ 产业分类结果
**输入描述**: {description}
**预测产业**: **{predicted_industry}**
**匹配关键词**:
"""
matched_keywords = [kw for kw in INDUSTRY_KEYWORDS[predicted_industry] if kw.lower() in description_lower]
classification_report += "- " + ", ".join(matched_keywords) if matched_keywords else "- (基于语义分析)"
classification_report += "\n\n**产业概率分布**:\n"
total_score = sum(industry_scores.values())
for ind, score in sorted(industry_scores.items(), key=lambda x: x[1], reverse=True):
prob = score / total_score * 100 if total_score > 0 else 0
classification_report += f"- {ind}: {prob:.1f}%\n"
# 推荐该产业的专利
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = current_data.to(device)
# 找到该产业的企业
company_industries = data['company'].industry.cpu().numpy()
industry_companies = np.where(company_industries == industry_idx)[0]
if len(industry_companies) == 0:
return classification_report, None, None, None
# 找到这些企业持有的专利
edge_index = data['company', 'owns', 'patent'].edge_index.cpu().numpy()
industry_patents = set()
for company_idx in industry_companies:
patent_mask = edge_index[0] == company_idx
industry_patents.update(edge_index[1][patent_mask].tolist())
industry_patents = list(industry_patents)
if len(industry_patents) == 0:
return classification_report, None, None, None
# 如果训练了专利价值模型,按价值排序
if 'patent_value' in loaded_models:
model = loaded_models['patent_value']['model']
value_predictor = loaded_models['patent_value']['predictor']
model.eval()
value_predictor.eval()
with torch.no_grad():
x_dict = model(data.x_dict, data.edge_index_dict)
all_values = value_predictor(x_dict['patent']).squeeze().cpu().numpy()
# 获取该产业专利的价值
patent_values = [(idx, all_values[idx]) for idx in industry_patents]
patent_values.sort(key=lambda x: x[1], reverse=True)
top_patents = patent_values[:num_patents]
else:
# 随机选择
top_patents = [(idx, 0) for idx in np.random.choice(industry_patents, min(num_patents, len(industry_patents)), replace=False)]
# 构建推荐列表
recommendations = []
for rank, (patent_idx, value) in enumerate(top_patents, 1):
patent_features = data['patent'].x[patent_idx].cpu().numpy()
recommendations.append({
"排名": rank,
"专利ID": f"P-{patent_idx}",
"价值评分": f"{value:.2f}" if value > 0 else "N/A",
"权利要求数": int(patent_features[1] * 30),
"引用数": int(patent_features[3] * 50),
"技术宽度": int(patent_features[4] * 5),
"授权状态": "已授权" if patent_features[6] > 0.5 else "未授权"
})
df = pd.DataFrame(recommendations)
# 可视化
fig = go.Figure()
if value > 0: # 有价值评分
fig.add_trace(go.Bar(
x=[r["专利ID"] for r in recommendations],
y=[float(r["价值评分"]) for r in recommendations],
marker=dict(
color=[float(r["价值评分"]) for r in recommendations],
colorscale='Greens'
),
name='价值评分'
))
fig.update_layout(yaxis_title="价值评分")
else:
fig.add_trace(go.Bar(
x=[r["专利ID"] for r in recommendations],
y=[r["引用数"] for r in recommendations],
marker=dict(color='#4ECDC4'),
name='引用数'
))
fig.update_layout(yaxis_title="引用数")
fig.update_layout(
title=f"📚 {predicted_industry} 产业推荐专利 (Top {num_patents})",
xaxis_title="专利ID",
template="plotly_white",
height=400
)
# 产业分布可视化
prob_fig = go.Figure(data=[
go.Bar(
x=list(industry_scores.keys()),
y=list(industry_scores.values()),
marker=dict(
color=['#FF6B6B' if ind == predicted_industry else '#E0E0E0' for ind in industry_scores.keys()]
)
)
])
prob_fig.update_layout(
title="🎯 产业匹配度分析",
xaxis_title="产业类别",
yaxis_title="匹配得分",
template="plotly_white",
height=350
)
return classification_report, df, fig, prob_fig
# ==================== 新功能3: 实体详情查看 ====================
def view_entity_details(entity_type, entity_id):
"""查看实体详细信息"""
global current_data
if current_data is None:
return "❌ 请先加载数据集!", None, None
try:
entity_id = int(entity_id)
except:
return "❌ 请输入有效的ID数字", None, None
type_mapping = {
"企业 (Company)": "company",
"专利 (Patent)": "patent",
"商标 (Trademark)": "trademark",
"人员 (Person)": "person",
"机构 (Institution)": "institution"
}
node_type = type_mapping.get(entity_type)
if node_type is None:
return "❌ 无效的实体类型", None, None
if entity_id >= current_data[node_type].num_nodes or entity_id < 0:
return f"❌ ID超出范围 (0-{current_data[node_type].num_nodes-1})", None, None
# 获取基本信息
features = current_data[node_type].x[entity_id].cpu().numpy()
# 构建详情报告
report = f"""
## 📋 {entity_type} 详细信息
**ID**: {node_type.upper()}-{entity_id}
### 基本特征
"""
# 根据不同类型显示不同特征
if node_type == "company":
industries = ['金融科技', '生物医药', '人工智能', '半导体', '新能源', '电子商务', '物流科技', '智能制造']
districts = ['中环', '湾仔', '尖沙咀', '观塘', '荃湾', '科学园', '数码港', '将军澳工业邨']
industry_idx = current_data['company'].industry[entity_id].item()
report += f"""
- **企业规模**: {int(features[0] * 500)}
- **成立年限**: {features[1] * 30:.1f}
- **研发投入比例**: {features[2] * 0.3 * 100:.1f}%
- **国际化程度**: {features[3] * 100:.1f}%
- **创新能力评分**: {features[4] * 100:.1f}/100
- **年营收**: {np.expm1(features[5] * 10):.1f} 百万港币
- **所属产业**: {industries[industry_idx]}
- **所在地区**: {districts[int(features[7] * len(districts))]}
"""
elif node_type == "patent":
report += f"""
- **申请年份**: {2015 + int(features[0] * 10)}
- **权利要求数**: {int(features[1] * 30)}
- **发明人数**: {int(features[2] * 10)}
- **引用数**: {int(features[3] * 50)}
- **技术宽度**: {int(features[4] * 5)}
- **价值评分**: {features[5] * 100:.1f}/100
- **授权状态**: {'✅ 已授权' if features[6] > 0.5 else '⏳ 未授权'}
- **IPC编码**: {int(features[7] * 100)}
"""
elif node_type == "trademark":
report += f"""
- **注册年份**: {2015 + int(features[0] * 10)}
- **商标类别**: {int(features[1] * 45) + 1}
- **续展次数**: {int(features[2] * 3)}
- **商标类型**: {['文字', '图形', '组合'][int(features[3] * 2)]}
- **知名度评分**: {features[4] * 100:.1f}/100
- **争议记录**: {int(features[5] * 10)}
"""
elif node_type == "person":
report += f"""
- **学历**: {['本科', '硕士', '博士'][int(features[0] * 2)]}
- **工作年限**: {features[1] * 40:.1f}
- **专利发明数**: {int(features[2] * 50)}
- **技术领域数**: {int(features[3] * 5)}
- **H指数**: {int(features[4] * 50)}
- **跨界合作能力**: {features[5] * 100:.1f}%
"""
elif node_type == "institution":
inst_types = ['大学', '研究所', '孵化器', '政府实验室']
report += f"""
- **机构类型**: {inst_types[int(features[0] * 4)]}
- **建立年限**: {features[1] * 100:.1f}
- **研究人员数**: {int(np.expm1(features[2] * np.log1p(1000)))}
- **年度专利产出**: {int(features[3] * 100)}
- **国际排名**: Top {int(1/features[4])}
- **产学研合作数**: {int(features[5] * 50)}
"""
# 获取关系信息
report += "\n### 关系网络\n"
relationships = []
for edge_type in current_data.edge_types:
src_type, rel, dst_type = edge_type
if src_type == node_type:
edge_index = current_data[edge_type].edge_index.cpu().numpy()
mask = edge_index[0] == entity_id
related_ids = edge_index[1][mask]
if len(related_ids) > 0:
relationships.append({
"关系": f"{rel}{dst_type}",
"数量": len(related_ids),
"示例": f"{dst_type.upper()}-{related_ids[0]}" if len(related_ids) > 0 else "N/A"
})
if dst_type == node_type:
edge_index = current_data[edge_type].edge_index.cpu().numpy()
mask = edge_index[1] == entity_id
related_ids = edge_index[0][mask]
if len(related_ids) > 0:
relationships.append({
"关系": f"{src_type}{rel}",
"数量": len(related_ids),
"示例": f"{src_type.upper()}-{related_ids[0]}" if len(related_ids) > 0 else "N/A"
})
if relationships:
rel_df = pd.DataFrame(relationships)
report += f"\n{rel_df.to_markdown(index=False)}\n"
else:
report += "- 暂无关系数据\n"
# 关系网络可视化
G = nx.Graph()
G.add_node(f"{node_type}-{entity_id}", node_type='center', color='#FF6B6B', size=20)
# 添加直接相关的节点(限制数量)
max_neighbors = 20
for edge_type in current_data.edge_types:
src_type, rel, dst_type = edge_type
if src_type == node_type:
edge_index = current_data[edge_type].edge_index.cpu().numpy()
mask = edge_index[0] == entity_id
related_ids = edge_index[1][mask][:max_neighbors]
for rid in related_ids:
G.add_node(f"{dst_type}-{rid}", node_type=dst_type, color='#4ECDC4', size=10)
G.add_edge(f"{node_type}-{entity_id}", f"{dst_type}-{rid}")
if dst_type == node_type:
edge_index = current_data[edge_type].edge_index.cpu().numpy()
mask = edge_index[1] == entity_id
related_ids = edge_index[0][mask][:max_neighbors]
for rid in related_ids:
G.add_node(f"{src_type}-{rid}", node_type=src_type, color='#45B7D1', size=10)
G.add_edge(f"{src_type}-{rid}", f"{node_type}-{entity_id}")
# 绘制图
pos = nx.spring_layout(G, k=1, iterations=50)
edge_trace = go.Scatter(
x=[], y=[],
line=dict(width=0.5, color='#888'),
hoverinfo='none',
mode='lines'
)
for edge in G.edges():
x0, y0 = pos[edge[0]]
x1, y1 = pos[edge[1]]
edge_trace['x'] += tuple([x0, x1, None])
edge_trace['y'] += tuple([y0, y1, None])
node_trace = go.Scatter(
x=[], y=[],
mode='markers+text',
text=[],
textposition="top center",
marker=dict(size=[], color=[], line=dict(width=2, color='white')),
hoverinfo='text'
)
for node in G.nodes():
x, y = pos[node]
node_trace['x'] += tuple([x])
node_trace['y'] += tuple([y])
node_data = G.nodes[node]
node_trace['marker']['size'] += tuple([node_data.get('size', 10)])
node_trace['marker']['color'] += tuple([node_data.get('color', '#888')])
node_trace['text'] += tuple([node.split('-')[0]])
fig = go.Figure(data=[edge_trace, node_trace])
fig.update_layout(
title=f"🌐 {node_type.upper()}-{entity_id} 关系网络图谱",
showlegend=False,
hovermode='closest',
template='plotly_white',
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
height=500
)
return report, rel_df if relationships else None, fig
# ==================== 继续其他预测函数 ====================
def predict_node_class(company_id):
    """Predict a company's industry distribution with the trained classifier.

    Requires `current_data` and a trained 'node_classification' model in the
    module globals. Returns (DataFrame sorted by probability, Plotly bar
    figure highlighting the predicted class), or an error string plus None.
    """
    global current_data, loaded_models
    if current_data is None or 'node_classification' not in loaded_models:
        return "❌ 请先加载数据并训练节点分类模型!", None
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data = current_data.to(device)
    model = loaded_models['node_classification']['model']
    classifier = loaded_models['node_classification']['classifier']
    model.eval()
    classifier.eval()
    industries = ['金融科技', '生物医药', '人工智能', '半导体', '新能源', '电子商务', '物流科技', '智能制造']
    # Clamp out-of-range IDs instead of erroring.
    company_idx = min(company_id, data['company'].num_nodes - 1)
    with torch.no_grad():
        x_dict = model(data.x_dict, data.edge_index_dict)
        out = classifier(x_dict['company'])
        # Softmax over the single company's logits -> class probabilities.
        probs = torch.softmax(out[company_idx], dim=0).cpu().numpy()
        pred_class = probs.argmax()
    results = []
    for i, prob in enumerate(probs):
        results.append({
            "产业类别": industries[i] if i < len(industries) else f"产业{i}",
            "预测概率": f"{prob:.4f}",
            "百分比": f"{prob * 100:.2f}%"
        })
    # NOTE: sorts the formatted string column — works here because all values
    # share the same "0.xxxx" width.
    df = pd.DataFrame(results).sort_values("预测概率", ascending=False)
    # Bar chart with the predicted class highlighted in red.
    fig = go.Figure(data=[
        go.Bar(x=[r["产业类别"] for r in results],
               y=[float(r["预测概率"]) for r in results],
               marker=dict(color=['#FF6B6B' if i == pred_class else '#E0E0E0'
                                  for i in range(len(results))]))
    ])
    fig.update_layout(
        title=f"企业 C-{company_idx} 的产业分类预测",
        xaxis_title="产业类别",
        yaxis_title="预测概率",
        template="plotly_white",
        height=400
    )
    return df, fig
def predict_patent_value(patent_id):
    """Evaluate one patent with the trained value model and build a report.

    Requires `current_data` and a trained 'patent_value' model in the module
    globals. Returns (markdown report, Plotly gauge+bar figure, raw-feature
    DataFrame), or an error string plus Nones.
    """
    global current_data, loaded_models
    if current_data is None or 'patent_value' not in loaded_models:
        return "❌ 请先加载数据并训练专利价值评估模型!", None, None
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data = current_data.to(device)
    model = loaded_models['patent_value']['model']
    value_predictor = loaded_models['patent_value']['predictor']
    model.eval()
    value_predictor.eval()
    # Clamp out-of-range IDs instead of erroring.
    patent_idx = min(patent_id, data['patent'].num_nodes - 1)
    with torch.no_grad():
        x_dict = model(data.x_dict, data.edge_index_dict)
        # return_aux=True yields value plus grant/renewal probability heads.
        value, grant_prob, renewal_prob = value_predictor(x_dict['patent'], return_aux=True)
        value_score = value[patent_idx].item()
        grant_p = grant_prob[patent_idx].item()
        renewal_p = renewal_prob[patent_idx].item()
    report = f"""
## 💎 专利价值评估报告
**专利ID**: P-{patent_idx}
### 核心指标
- **综合价值评分**: {value_score:.2f} / 100
- **授权概率**: {grant_p * 100:.2f}%
- **续费概率**: {renewal_p * 100:.2f}%
### 价值等级
"""
    # Tier the patent by score: >=80 high, >=60 medium, otherwise ordinary.
    if value_score >= 80:
        report += "🌟 **高价值专利** - 建议重点保护和商业化开发"
    elif value_score >= 60:
        report += "⭐ **中等价值专利** - 具有一定商业潜力"
    else:
        report += "📄 **一般专利** - 基础性专利,价值有限"
    # Combined figure: gauge (left half) + probability bars (right half).
    fig = go.Figure()
    fig.add_trace(go.Indicator(
        mode="gauge+number+delta",
        value=value_score,
        domain={'x': [0, 0.5], 'y': [0, 1]},
        title={'text': "综合价值评分"},
        gauge={
            'axis': {'range': [None, 100]},
            'bar': {'color': "#4ECDC4"},
            'steps': [
                {'range': [0, 40], 'color': "#FFE5E5"},
                {'range': [40, 70], 'color': "#FFF5CC"},
                {'range': [70, 100], 'color': "#E8F5E9"}
            ],
            'threshold': {
                'line': {'color': "red", 'width': 4},
                'thickness': 0.75,
                'value': 80
            }
        }
    ))
    fig.add_trace(go.Bar(
        x=['授权概率', '续费概率'],
        y=[grant_p * 100, renewal_p * 100],
        marker=dict(color=['#FF6B6B', '#45B7D1']),
        text=[f"{grant_p * 100:.1f}%", f"{renewal_p * 100:.1f}%"],
        textposition='outside'
    ))
    fig.update_layout(
        title=f"专利 P-{patent_idx} 价值分析",
        template="plotly_white",
        height=400,
        xaxis={'domain': [0.6, 1]},
        yaxis={'domain': [0, 1], 'title': '概率 (%)'}
    )
    # Raw normalized feature vector for inspection.
    patent_features = data['patent'].x[patent_idx].cpu().numpy()
    feature_names = ['申请年份', '权利要求数', '发明人数', '引用数', '技术宽度', '现有价值', '授权状态', 'IPC编码']
    feature_df = pd.DataFrame({
        '特征': feature_names,
        '数值': [f"{val:.3f}" for val in patent_features]
    })
    return report, fig, feature_df
def recommend_collaboration(company_id, num_recommendations):
    """Recommend collaboration partners for one company.

    Scores the chosen company against every other company with the trained
    'collaboration' model. Returns (DataFrame of top partners, radar figure
    for the top 5), or an error string plus None.
    """
    global current_data, loaded_models
    if current_data is None or 'collaboration' not in loaded_models:
        return "❌ 请先加载数据并训练合作推荐模型!", None
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data = current_data.to(device)
    model = loaded_models['collaboration']['model']
    collab_recommender = loaded_models['collaboration']['recommender']
    model.eval()
    collab_recommender.eval()
    # Clamp out-of-range IDs instead of erroring.
    company_idx = min(company_id, data['company'].num_nodes - 1)
    num_companies = data['company'].num_nodes
    with torch.no_grad():
        x_dict = model(data.x_dict, data.edge_index_dict)
        # Candidate pairs: the chosen company vs every company (incl. itself).
        edge_candidates = torch.stack([
            torch.full((num_companies,), company_idx, dtype=torch.long),
            torch.arange(num_companies)
        ]).to(device)
        results = collab_recommender(x_dict['company'], x_dict['company'], edge_candidates)
        success_probs = results['success_probability'].cpu().numpy()
        tech_sims = results['tech_similarity'].cpu().numpy()
        market_sims = results['market_similarity'].cpu().numpy()
        complements = results['complementarity'].cpu().numpy()
    # Zero out the self-pair so the company never recommends itself.
    success_probs[company_idx] = 0
    top_indices = np.argsort(success_probs)[-num_recommendations:][::-1]
    recommendations = []
    for idx in top_indices:
        recommendations.append({
            "合作企业ID": f"C-{idx}",
            "成功概率": f"{success_probs[idx]:.4f}",
            "技术相似度": f"{tech_sims[idx]:.4f}",
            "市场相似度": f"{market_sims[idx]:.4f}",
            "互补性": f"{complements[idx]:.4f}"
        })
    df = pd.DataFrame(recommendations)
    # Radar chart overlaying the four scores for up to 5 top partners.
    fig = go.Figure()
    for i, rec in enumerate(recommendations[:5]):
        fig.add_trace(go.Scatterpolar(
            r=[float(rec["成功概率"]), float(rec["技术相似度"]),
               float(rec["市场相似度"]), float(rec["互补性"])],
            theta=['成功概率', '技术相似度', '市场相似度', '互补性'],
            fill='toself',
            name=rec["合作企业ID"]
        ))
    fig.update_layout(
        polar=dict(radialaxis=dict(visible=True, range=[0, 1])),
        title=f"企业 C-{company_idx} 的合作推荐分析 (Top-{min(5, num_recommendations)})",
        template="plotly_white",
        height=500
    )
    return df, fig
# ==================== Gradio界面构建 ====================
def build_gradio_app():
    """Build the enhanced Gradio application.

    Assembles all UI tabs (data management, data-science analysis, model
    training, prediction/recommendation views and a system-info page) and
    wires each button to its backend callback.  Returns the ``gr.Blocks``
    app; the caller is responsible for launching it.
    """
    with gr.Blocks(title="香港知识产权生态网络分析系统 (数据科学增强版)", theme=gr.themes.Soft()) as app:
        gr.Markdown("""
        # 🏙️ 香港知识产权生态网络分析系统
        ### 基于异构图神经网络的智能分析平台 - 数据科学增强版 v3.0
        """)
        with gr.Tabs():
            # ========== Tab 1: dataset generation / loading ==========
            with gr.Tab("📊 数据管理"):
                gr.Markdown("## 数据集生成与加载")
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### 快速加载")
                        dataset_size = gr.Dropdown(
                            choices=["test", "medium", "large"],
                            value="medium",
                            label="选择数据集规模"
                        )
                        load_btn = gr.Button("📂 加载数据集", variant="primary")
                        gr.Markdown("### 自定义生成")
                        with gr.Accordion("高级配置", open=False):
                            custom_size = gr.Textbox(value="custom", label="数据集名称")
                            n_companies = gr.Slider(100, 2000, value=500, step=100, label="企业数量")
                            n_patents = gr.Slider(500, 15000, value=3000, step=500, label="专利数量")
                            n_trademarks = gr.Slider(200, 8000, value=1500, step=200, label="商标数量")
                            n_persons = gr.Slider(300, 10000, value=2000, step=300, label="人员数量")
                            n_institutions = gr.Slider(20, 200, value=50, step=10, label="机构数量")
                            time_span = gr.Slider(5, 20, value=10, step=1, label="时间跨度(年)")
                        generate_btn = gr.Button("🔨 生成数据集", variant="secondary")
                    with gr.Column(scale=2):
                        data_status = gr.Textbox(label="状态信息", interactive=False)
                        data_stats = gr.Dataframe(label="数据集统计")
                with gr.Row():
                    detailed_stats = gr.Markdown(label="详细统计报告")
                with gr.Row():
                    node_chart = gr.Plot(label="节点分布")
                    edge_chart = gr.Plot(label="关系分布")
                with gr.Row():
                    network_viz = gr.Plot(label="网络图谱可视化")
                # Event wiring: load/generate, then refresh the three charts.
                load_btn.click(
                    load_dataset,
                    inputs=[dataset_size],
                    outputs=[data_status, data_stats, detailed_stats]
                ).then(
                    lambda: [visualize_network_overview(), visualize_edge_distribution(), visualize_network_graph()],
                    outputs=[node_chart, edge_chart, network_viz]
                )
                generate_btn.click(
                    generate_dataset,
                    inputs=[custom_size, n_companies, n_patents, n_trademarks, n_persons, n_institutions, time_span],
                    outputs=[data_status, data_stats, detailed_stats]
                ).then(
                    lambda: [visualize_network_overview(), visualize_edge_distribution(), visualize_network_graph()],
                    outputs=[node_chart, edge_chart, network_viz]
                )
            # ========== New in v3.0: data-science analysis sub-tabs ==========
            with gr.Tab("🔬 数据科学分析"):
                gr.Markdown("## 高级数据分析与可视化")
                with gr.Tabs():
                    # PCA dimensionality reduction
                    with gr.Tab("📐 PCA降维"):
                        gr.Markdown("### 主成分分析 (Principal Component Analysis)")
                        with gr.Row():
                            with gr.Column(scale=1):
                                pca_node_type = gr.Dropdown(
                                    choices=['company', 'patent', 'trademark', 'person', 'institution'],
                                    value='company',
                                    label="选择节点类型"
                                )
                                pca_n_components = gr.Slider(
                                    2, 10, value=2, step=1,
                                    label="降维目标维度"
                                )
                                pca_run_btn = gr.Button("🚀 执行PCA分析", variant="primary", size="lg")
                                gr.Markdown("""
                                **PCA说明**:
                                - 线性降维方法
                                - 保留最大方差
                                - 适合数据预处理
                                """)
                            with gr.Column(scale=2):
                                pca_report = gr.Markdown(label="分析报告")
                                with gr.Row():
                                    pca_variance_plot = gr.Plot(label="方差分析")
                                    pca_projection_plot = gr.Plot(label="PCA投影")
                                pca_components_table = gr.Dataframe(label="主成分载荷 (Top 10特征)")
                        pca_run_btn.click(
                            perform_pca_analysis,
                            inputs=[pca_node_type, pca_n_components],
                            outputs=[pca_report, pca_variance_plot, pca_projection_plot, pca_components_table]
                        )
                    # t-SNE dimensionality reduction
                    with gr.Tab("🎯 t-SNE降维"):
                        gr.Markdown("### t-分布随机邻域嵌入 (t-SNE)")
                        with gr.Row():
                            with gr.Column(scale=1):
                                tsne_node_type = gr.Dropdown(
                                    choices=['company', 'patent', 'trademark', 'person', 'institution'],
                                    value='company',
                                    label="选择节点类型"
                                )
                                tsne_n_components = gr.Radio(
                                    choices=[2, 3],
                                    value=2,
                                    label="降维维度"
                                )
                                tsne_perplexity = gr.Slider(
                                    5, 50, value=30, step=5,
                                    label="困惑度 (Perplexity)"
                                )
                                tsne_n_iter = gr.Slider(
                                    250, 2000, value=1000, step=250,
                                    label="迭代次数"
                                )
                                tsne_run_btn = gr.Button("🚀 执行t-SNE分析", variant="primary", size="lg")
                                gr.Markdown("""
                                **t-SNE说明**:
                                - 非线性降维
                                - 保留局部结构
                                - 适合聚类可视化
                                ⚠️ 计算较慢,大数据集会自动采样
                                """)
                            with gr.Column(scale=2):
                                tsne_report = gr.Markdown(label="分析报告")
                                tsne_plot = gr.Plot(label="t-SNE可视化")
                                tsne_metrics = gr.Dataframe(label="聚类质量指标")
                        tsne_run_btn.click(
                            perform_tsne_analysis,
                            inputs=[tsne_node_type, tsne_n_components, tsne_perplexity, tsne_n_iter],
                            outputs=[tsne_report, tsne_plot, tsne_metrics]
                        )
                    # Unsupervised clustering
                    with gr.Tab("🎯 聚类分析"):
                        gr.Markdown("### 无监督聚类")
                        with gr.Row():
                            with gr.Column(scale=1):
                                cluster_node_type = gr.Dropdown(
                                    choices=['company', 'patent', 'trademark', 'person', 'institution'],
                                    value='company',
                                    label="选择节点类型"
                                )
                                cluster_method = gr.Dropdown(
                                    choices=['kmeans', 'dbscan', 'hierarchical'],
                                    value='kmeans',
                                    label="聚类方法"
                                )
                                cluster_n_clusters = gr.Slider(
                                    2, 20, value=5, step=1,
                                    label="聚类数量 (K-means/层次聚类)"
                                )
                                cluster_run_btn = gr.Button("🚀 执行聚类分析", variant="primary", size="lg")
                                gr.Markdown("""
                                **聚类方法**:
                                - **K-means**: 快速,需指定K值
                                - **DBSCAN**: 基于密度,自动确定簇数
                                - **层次聚类**: 层次结构,需指定K值
                                """)
                            with gr.Column(scale=2):
                                cluster_report = gr.Markdown(label="聚类报告")
                                cluster_plot = gr.Plot(label="聚类可视化")
                                cluster_stats = gr.Dataframe(label="聚类统计")
                        cluster_run_btn.click(
                            perform_clustering_analysis,
                            inputs=[cluster_node_type, cluster_method, cluster_n_clusters],
                            outputs=[cluster_report, cluster_plot, cluster_stats]
                        )
                    # Feature correlation analysis
                    with gr.Tab("📊 相关性分析"):
                        gr.Markdown("### 特征相关性分析")
                        with gr.Row():
                            with gr.Column(scale=1):
                                corr_node_type = gr.Dropdown(
                                    choices=['company', 'patent', 'trademark', 'person', 'institution'],
                                    value='company',
                                    label="选择节点类型"
                                )
                                corr_run_btn = gr.Button("🚀 分析特征相关性", variant="primary", size="lg")
                            with gr.Column(scale=2):
                                corr_report = gr.Markdown(label="相关性报告")
                                with gr.Row():
                                    corr_heatmap = gr.Plot(label="相关性热力图")
                                    corr_dist_plot = gr.Plot(label="特征分布")
                                corr_high_table = gr.Dataframe(label="高相关性特征对")
                        corr_run_btn.click(
                            perform_correlation_analysis,
                            inputs=[corr_node_type],
                            outputs=[corr_report, corr_heatmap, corr_high_table, corr_dist_plot]
                        )
                    # Global statistics dashboard
                    with gr.Tab("📈 统计仪表板"):
                        gr.Markdown("### 数据集整体统计概览")
                        stats_run_btn = gr.Button("📊 生成统计仪表板", variant="primary", size="lg")
                        stats_report = gr.Markdown(label="统计报告")
                        stats_dashboard = gr.Plot(label="统计仪表板")
                        stats_run_btn.click(
                            generate_statistics_dashboard,
                            inputs=[],
                            outputs=[stats_report, stats_dashboard]
                        )
            # ========== Tab 2: model training ==========
            with gr.Tab("🎯 模型训练"):
                gr.Markdown("## 异构图神经网络模型训练")
                with gr.Row():
                    with gr.Column(scale=1):
                        task_type = gr.Dropdown(
                            choices=["链接预测", "节点分类", "专利价值评估", "企业合作推荐"],
                            value="链接预测",
                            label="选择任务类型"
                        )
                        model_type = gr.Dropdown(
                            choices=["HGT", "HeteroGNN"],
                            value="HGT",
                            label="选择模型"
                        )
                        gr.Markdown("### 训练参数")
                        epochs = gr.Slider(10, 200, value=50, step=10, label="训练轮次")
                        hidden_channels = gr.Slider(32, 256, value=64, step=32, label="隐藏层维度")
                        learning_rate = gr.Slider(0.0001, 0.01, value=0.001, step=0.0001, label="学习率")
                        n_heads = gr.Slider(2, 16, value=8, step=2, label="注意力头数(HGT)")
                        num_layers = gr.Slider(1, 5, value=3, step=1, label="网络层数")
                        train_btn = gr.Button("🚀 开始训练", variant="primary", size="lg")
                    with gr.Column(scale=2):
                        train_status = gr.Textbox(label="训练状态", interactive=False)
                        train_curves = gr.Plot(label="训练曲线")
                        train_report = gr.Markdown(label="训练报告")
                train_btn.click(
                    train_model,
                    inputs=[task_type, model_type, epochs, hidden_channels, learning_rate, n_heads, num_layers],
                    outputs=[train_status, train_curves, train_report]
                )
            # ========== Tab 3: link prediction ==========
            with gr.Tab("🔗 链接预测"):
                gr.Markdown("## 企业-专利关系预测")
                with gr.Row():
                    with gr.Column(scale=1):
                        link_company_id = gr.Number(value=0, label="企业ID", precision=0)
                        link_num_pred = gr.Slider(5, 50, value=10, step=5, label="预测数量")
                        link_predict_btn = gr.Button("🔍 预测关联专利", variant="primary")
                    with gr.Column(scale=2):
                        link_results = gr.Dataframe(label="预测结果")
                        link_viz = gr.Plot(label="预测可视化")
                link_predict_btn.click(
                    predict_link,
                    inputs=[link_company_id, link_num_pred],
                    outputs=[link_results, link_viz]
                )
            # ========== Tab 4: industry classification + patent recommendation ==========
            with gr.Tab("🏷️ 智能产业分析"):
                gr.Markdown("## 基于描述的产业分类与专利推荐")
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### 方式1: 查看现有企业")
                        class_company_id = gr.Number(value=0, label="企业ID", precision=0)
                        class_predict_btn = gr.Button("🔍 分析企业产业", variant="secondary")
                    with gr.Column(scale=2):
                        class_results = gr.Dataframe(label="分类结果")
                        class_viz = gr.Plot(label="概率分布")
                gr.Markdown("---")
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### 方式2: 输入业务描述")
                        business_desc = gr.Textbox(
                            label="业务描述",
                            placeholder="例如: 我们公司专注于区块链技术和数字支付解决方案...",
                            lines=5
                        )
                        num_patent_rec = gr.Slider(5, 30, value=10, step=5, label="推荐专利数量")
                        classify_btn = gr.Button("🎯 分析并推荐专利", variant="primary")
                    with gr.Column(scale=2):
                        classification_report = gr.Markdown(label="产业分类报告")
                        patent_recommendations = gr.Dataframe(label="推荐专利列表")
                        with gr.Row():
                            industry_prob_viz = gr.Plot(label="产业匹配度")
                            patent_rec_viz = gr.Plot(label="推荐专利分析")
                # Event wiring for both analysis modes.
                class_predict_btn.click(
                    predict_node_class,
                    inputs=[class_company_id],
                    outputs=[class_results, class_viz]
                )
                classify_btn.click(
                    classify_and_recommend_by_description,
                    inputs=[business_desc, num_patent_rec],
                    outputs=[classification_report, patent_recommendations, patent_rec_viz, industry_prob_viz]
                )
            # ========== Tab 5: patent value assessment (leaderboard + single) ==========
            with gr.Tab("💎 专利价值"):
                gr.Markdown("## 专利价值智能评估")
                with gr.Tabs():
                    with gr.Tab("📋 价值排行榜"):
                        gr.Markdown("### 全局专利价值排行榜")
                        with gr.Row():
                            leaderboard_top_n = gr.Slider(10, 100, value=50, step=10, label="显示Top N专利")
                        leaderboard_btn = gr.Button("📊 生成排行榜", variant="primary", size="lg")
                        leaderboard_report = gr.Markdown(label="排行榜分析报告")
                        leaderboard_table = gr.Dataframe(label="专利价值排行榜")
                        leaderboard_viz = gr.Plot(label="排行榜可视化")
                        leaderboard_btn.click(
                            get_patent_value_leaderboard,
                            inputs=[leaderboard_top_n],
                            outputs=[leaderboard_report, leaderboard_table, leaderboard_viz]
                        )
                    with gr.Tab("🔍 单个专利评估"):
                        gr.Markdown("### 查询指定专利的详细价值")
                        with gr.Row():
                            with gr.Column(scale=1):
                                value_patent_id = gr.Number(value=0, label="专利ID", precision=0)
                                value_predict_btn = gr.Button("📊 评估专利价值", variant="primary")
                            with gr.Column(scale=2):
                                value_report = gr.Markdown(label="评估报告")
                                value_viz = gr.Plot(label="价值分析")
                                value_features = gr.Dataframe(label="专利特征")
                        value_predict_btn.click(
                            predict_patent_value,
                            inputs=[value_patent_id],
                            outputs=[value_report, value_viz, value_features]
                        )
            # ========== Tab 6: collaboration recommendation ==========
            with gr.Tab("🤝 合作推荐"):
                gr.Markdown("## 企业合作伙伴智能推荐")
                with gr.Row():
                    with gr.Column(scale=1):
                        collab_company_id = gr.Number(value=0, label="企业ID", precision=0)
                        collab_num_rec = gr.Slider(5, 20, value=10, step=5, label="推荐数量")
                        collab_recommend_btn = gr.Button("🎯 推荐合作伙伴", variant="primary")
                    with gr.Column(scale=2):
                        collab_results = gr.Dataframe(label="推荐结果")
                        collab_viz = gr.Plot(label="多维分析")
                collab_recommend_btn.click(
                    recommend_collaboration,
                    inputs=[collab_company_id, collab_num_rec],
                    outputs=[collab_results, collab_viz]
                )
            # ========== Tab 7: entity detail viewer ==========
            with gr.Tab("🔎 实体详情"):
                gr.Markdown("## 查看任意实体的详细信息")
                with gr.Row():
                    with gr.Column(scale=1):
                        entity_type_select = gr.Dropdown(
                            choices=[
                                "企业 (Company)",
                                "专利 (Patent)",
                                "商标 (Trademark)",
                                "人员 (Person)",
                                "机构 (Institution)"
                            ],
                            value="企业 (Company)",
                            label="实体类型"
                        )
                        entity_id_input = gr.Number(value=0, label="实体ID", precision=0)
                        view_details_btn = gr.Button("🔍 查看详情", variant="primary", size="lg")
                    with gr.Column(scale=2):
                        entity_details_report = gr.Markdown(label="实体详细信息")
                        entity_relations_table = gr.Dataframe(label="关系统计")
                with gr.Row():
                    entity_network_viz = gr.Plot(label="关系网络图谱")
                view_details_btn.click(
                    view_entity_details,
                    inputs=[entity_type_select, entity_id_input],
                    outputs=[entity_details_report, entity_relations_table, entity_network_viz]
                )
            # ========== Tab 8: system information page (static text) ==========
            with gr.Tab("ℹ️ 系统信息"):
                gr.Markdown("""
                ## 📚 系统说明 (v3.0 - 数据科学增强版)
                ### 🆕 最新更新 (v3.0)
                1. **PCA降维分析** - 主成分分析,查看方差解释和特征重要性
                2. **t-SNE可视化** - 非线性降维,发现数据聚类结构
                3. **聚类分析** - K-means、DBSCAN、层次聚类
                4. **相关性分析** - 特征相关性热力图和高相关特征识别
                5. **统计仪表板** - 数据集全局统计概览
                ### 🔬 数据科学功能
                - **降维**: PCA、t-SNE支持2D/3D可视化
                - **聚类**: 多种聚类算法,自动计算聚类质量指标
                - **统计**: 相关性分析、分布分析、质量评估
                - **可视化**: 交互式图表,支持缩放、悬停查看详情
                ### 📊 核心功能模块
                1. **数据管理** - 数据生成、加载、统计
                2. **数据科学分析** - PCA、t-SNE、聚类、相关性 (NEW!)
                3. **模型训练** - HGT/HeteroGNN多任务训练
                4. **链接预测** - 企业-专利关系预测
                5. **智能产业分析** - 产业分类+专利推荐
                6. **专利价值** - 排行榜+单个评估
                7. **合作推荐** - 多维度企业合作分析
                8. **实体详情** - 完整实体信息查看
                ### 🛠️ 技术栈
                - **深度学习**: PyTorch, PyTorch Geometric
                - **图模型**: HGT, HeteroGNN
                - **数据科学**: Scikit-learn, UMAP
                - **可视化**: Plotly, NetworkX
                - **界面**: Gradio
                ### 📖 使用指南
                1. **数据准备**: 在"数据管理"加载数据集
                2. **数据探索**: 在"数据科学分析"进行降维、聚类等分析
                3. **模型训练**: 在"模型训练"训练所需任务模型
                4. **应用分析**: 在各功能标签页进行预测和推荐
                ### 📈 典型工作流
                ```
                加载数据 → 数据分析 (PCA/t-SNE) → 模型训练 → 业务预测
                ↓ ↓ ↓ ↓
                统计分析 发现模式 优化模型 决策支持
                ```
                ### 🔗 更新日志
                **v3.0** (2025-10-27)
                - ✨ 新增完整的数据科学分析模块
                - 📊 支持PCA、t-SNE降维可视化
                - 🎯 支持多种聚类算法
                - 📈 新增特征相关性分析
                - 📉 新增统计分析仪表板
                - 🎨 优化可视化效果和交互体验
                **v2.0** (2025-10-26)
                - 专利价值排行榜
                - 智能产业分析
                - 实体详情查看
                **v1.0** (2025-10-25)
                - 基础图神经网络模型
                - 链接预测和节点分类
                ---
                **开发团队**: Math3836 Team | **版本**: v3.0 | **日期**: 2025-10-27
                💡 **提示**: 数据科学分析功能计算密集,大数据集可能需要较长时间
                """)
        # Page footer with general usage tips (shown under every tab).
        gr.Markdown("""
        ---
        🎓 **学习建议**:
        - 🔬 先使用"数据科学分析"探索数据特性
        - 📊 通过PCA了解特征重要性
        - 🎯 用t-SNE发现数据聚类模式
        - 📈 结合聚类分析验证模型效果
        - 🔗 最后应用训练好的模型进行预测
        """)
    return app
# ==================== Application entry point ====================
if __name__ == "__main__":
    # Ensure the working directories used by the app exist before launch.
    for _dir in ("data", "checkpoints"):
        os.makedirs(_dir, exist_ok=True)
    # Build the UI and serve it on all interfaces with a public share link.
    gradio_app = build_gradio_app()
    gradio_app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True,
    )