# PatentConceptor / app.py
# Author: LiuHongwei1992
# Commit: "Upload app.py with huggingface_hub" (0ad783f, verified)
"""
PatentConceptor - HuggingFace Spaces Gradio App
Patent-Based Cross-Domain Design Concept Generation System
"""
import os
import gradio as gr
from pathlib import Path
# ---- Configuration ----
# Online demo uses 10K for fast loading.
# For local deployment with 100K/500K, download from the project repository,
# place the CSV files in the data/ directory and register them in DATA_OPTIONS.
# Maps the label shown in the UI to the CSV file name inside DATA_DIR.
DATA_OPTIONS = {
    "10K (Quick Test, ~9MB)": "patents_sample_10k.csv",
}
# Fallback file used when a label is not found in DATA_OPTIONS. It is derived
# from DATA_OPTIONS so that it always names a dataset that is actually offered
# (it previously pointed at the 100K sample, which is not listed above).
DEFAULT_DATA = next(iter(DATA_OPTIONS.values()))
DATA_DIR = Path("data")

# Global cache: {data_path: kb_instance}
_kb_cache = {}
# Most recently created LLM client (None until the first initialisation).
_llm = None
# File name of the dataset selected by the most recent initialisation.
_current_data_file = None
def _get_env_api_key():
    """Return the SILICONFLOW_API_KEY environment variable, stripped.

    An unset variable yields the empty string.
    """
    raw_value = os.environ.get("SILICONFLOW_API_KEY")
    if raw_value is None:
        return ""
    return raw_value.strip()
def init_system(data_choice: str, api_key: str):
    """Initialise the knowledge base and LLM client, reusing cached instances.

    Args:
        data_choice: Dataset label chosen in the UI. Labels missing from
            ``DATA_OPTIONS`` fall back to ``DEFAULT_DATA``.
        api_key: SiliconFlow API key entered by the user. When blank (or
            ``None``) the ``SILICONFLOW_API_KEY`` environment variable is used.

    Returns:
        A ``(message, kb, llm)`` tuple. ``kb`` and ``llm`` are both ``None``
        whenever initialisation could not complete, in which case ``message``
        explains why (missing data file, missing key, or the exception text).
    """
    global _kb_cache, _llm, _current_data_file

    data_file = DATA_OPTIONS.get(data_choice, DEFAULT_DATA)
    data_path = DATA_DIR / data_file
    if not data_path.exists():
        return (
            f"❌ Data file not found: {data_path}\n"
            f"Please copy patent CSV files to the data/ directory.",
            None, None,
        )

    # `api_key` may be None if the textbox value was never set.
    key = (api_key or "").strip() or _get_env_api_key()
    if not key:
        return "⚠️ Please provide a SiliconFlow API Key", None, None
    # NOTE(review): this stores a user-supplied key process-wide, so later
    # requests that leave the key blank will pick it up. Kept because modules
    # outside this file may read the variable - confirm and drop if unused.
    os.environ["SILICONFLOW_API_KEY"] = key

    # Every cached knowledge base was wired to the client created with the
    # previous key. If the key changed, reusing them would keep calling the
    # API with the old (possibly invalid) key, so they are discarded.
    if _llm is not None and getattr(init_system, "_llm_key", None) != key:
        _kb_cache.clear()

    # Reuse the dataset if it is already loaded.
    cache_key = str(data_path)
    cached_kb = _kb_cache.get(cache_key)
    if cached_kb is not None and _llm is not None:
        _current_data_file = data_file
        return (
            f"✅ Using cached dataset\n"
            f" Dataset: {data_file}\n"
            f" Patents: {len(cached_kb.patents)}\n"
            f" API Key: configured",
            cached_kb,
            _llm,
        )

    try:
        # Project imports live inside the try so that a missing or broken
        # module is reported through the normal failure tuple.
        import patent_kb_v2 as kb_mod
        from llm_client import SiliconFlowClient
        from patent_kb_v2 import PatentKnowledgeBaseV2
        from semantic_search_sklearn import LSASemanticSearch, HybridSearchEngineLSA

        llm = SiliconFlowClient(api_key=key)
        # The client is published on the KB module before the KB is built.
        kb_mod._llm_client = llm
        kb = PatentKnowledgeBaseV2(str(data_path))
        kb.load()

        # Initialise the LSA + hybrid search engine.
        semantic_engine = LSASemanticSearch(n_components=100)
        semantic_engine.build_index(kb.patents)
        hybrid_engine = HybridSearchEngineLSA(kb, semantic_engine, alpha=0.6)
        hybrid_engine.set_llm_client(llm)
        kb.set_hybrid_engine(hybrid_engine)
    except Exception as e:
        import traceback

        # The UI only shows the message; keep the full trace in the logs.
        traceback.print_exc()
        return f"❌ Initialization failed: {str(e)}", None, None

    # Commit module state only after everything was built successfully.
    _llm = llm
    init_system._llm_key = key
    _kb_cache[cache_key] = kb
    _current_data_file = data_file
    info = (
        f"✅ System initialized (first load, may take 1-3 min for 500K)\n"
        f" Dataset: {data_file}\n"
        f" Patents: {len(kb.patents)}\n"
        f" Hybrid Search: enabled\n"
        f" API Key: configured"
    )
    return info, kb, _llm
# Cross-domain hits have their score multiplied by this factor before ranking.
_CROSS_DOMAIN_PENALTY = 0.9
# Only this many leading cross-domain target sections are searched.
_MAX_CROSS_SECTIONS = 3


def _is_same_domain(patent):
    """Return True when the hit was tagged as coming from a source domain."""
    return patent.get("domain_type", "").startswith("same_")


def _is_cross_domain(patent):
    """Return True when the hit was tagged as coming from a cross-domain target."""
    return patent.get("domain_type", "").startswith("cross_")


def _collect_candidates(query, top_k, kb, source_domains, target_domains):
    """Search each source and target section and return the tagged hits.

    Hits are shallow-copied before tagging so the dictionaries owned by the
    search engine are never modified; otherwise the cross-domain penalty would
    compound if the engine handed out the same objects again.
    A failure in one section is logged and does not abort the other sections.
    """
    plan = [("same", section) for section in source_domains]
    plan += [("cross", section) for section in target_domains[:_MAX_CROSS_SECTIONS]]

    candidates = []
    for origin, section in plan:
        try:
            hits = [dict(hit) for hit in kb.search_hybrid(query, section, top_k=top_k)]
            for hit in hits:
                hit["domain_type"] = f"{origin}_{section}"
                if origin == "cross":
                    hit["score"] = hit.get("score", 0) * _CROSS_DOMAIN_PENALTY
            candidates.extend(hits)
        except Exception as e:
            label = section if origin == "same" else f"cross {section}"
            print(f"Error searching {label}: {e}")
    return candidates


def _dedupe_by_title(patents):
    """Sort by descending score and keep the first hit for every title.

    Hits without a title are dropped.
    """
    seen_titles = set()
    unique = []
    for patent in sorted(patents, key=lambda x: x.get("score", 0), reverse=True):
        title = patent.get("title", "")
        if title and title not in seen_titles:
            seen_titles.add(title)
            unique.append(patent)
    return unique


def _format_config_info(patents, query):
    """Describe the query configuration recorded on the first retrieved hit."""
    if not patents:
        return ""
    first = patents[0]
    parts = [f"📝 Query Type: {first.get('query_type', 'N/A')}\n"]
    alpha = first.get("alpha_used", None)
    if alpha is not None:
        parts.append(
            f"⚖️ Hybrid Weight: TF-IDF {round(alpha*100)}% + LSA {round((1-alpha)*100)}%\n"
        )
    parts.append(f"🔤 Functional Keywords: {first.get('search_query', query)}\n")
    return "".join(parts)


def _format_patent_lines(patents):
    """Render one numbered line per patent: origin, IPC, title and score."""
    lines = []
    for i, r in enumerate(patents, 1):
        title = r.get("title", "N/A")[:55]
        ipc = r.get("ipc", "N/A")
        score = r.get("score", 0)
        dtype = "same" if _is_same_domain(r) else "cross"
        lines.append(f" {i}. [{dtype}] [{ipc}] {title} (score={score:.3f})\n")
    return "".join(lines)


def run_single_mode(query: str, top_k: int, kb, llm):
    """Single Mode: deepen one high-relevance patent (retrieval shows all domains).

    Args:
        query: Design problem in natural language.
        top_k: Hits requested per section and size of the final ranked list.
        kb: Knowledge base providing ``analyze_domains``,
            ``get_cross_domain_targets`` and ``search_hybrid``.
        llm: Client providing ``generate_design_concept``.

    Returns:
        ``(retrieval_text, concept_text)``. On failure the first element holds
        the error message and the second is empty.
    """
    if not kb or not llm:
        return "⚠️ Please initialize the system first", ""
    try:
        # A float value (e.g. from a UI slider) cannot be used as a slice bound.
        if isinstance(top_k, float):
            top_k = int(top_k)

        # Domain analysis
        source_domains = kb.analyze_domains(query)
        target_domains = kb.get_cross_domain_targets(source_domains)

        # Hybrid search over source domains + cross-domain targets, ranked together
        candidates = _collect_candidates(query, top_k, kb, source_domains, target_domains)
        patents = _dedupe_by_title(candidates)[:top_k]

        # Single mode hands only the best-ranked patent to the LLM.
        concept = llm.generate_design_concept(
            query=query,
            patents=patents[:1],
            source_domains=source_domains,
            target_domains=target_domains,
            mode="single",
        )

        # Format output
        domain_info = f"📊 Domains: {source_domains}\n🎯 Cross-domain Targets: {target_domains}"
        config_info = _format_config_info(patents, query)
        same_count = sum(1 for p in patents if _is_same_domain(p))
        cross_count = sum(1 for p in patents if _is_cross_domain(p))
        patents_info = (
            f"🔍 Retrieved Patents (same: {same_count}, cross: {cross_count}):\n"
            + _format_patent_lines(patents)
        )
        concept_text = concept.get("design_text", "")
        concept_info = f"💡 Design Concept ({concept.get('mode_name', 'Single')}):\n\n{concept_text}"
        return domain_info + "\n" + config_info + "\n" + patents_info, concept_info
    except Exception as e:
        return f"❌ Error: {str(e)}", ""


def run_multi_mode(query: str, top_k: int, kb, llm):
    """Multi Mode: fuse cross-domain patents.

    Args:
        query: Design problem in natural language.
        top_k: Hits requested per section from the hybrid search.
        kb: Knowledge base providing ``analyze_domains``,
            ``get_cross_domain_targets`` and ``search_hybrid``.
        llm: Client providing ``generate_design_concept``.

    Returns:
        ``(retrieval_text, concept_text)``. On failure the first element holds
        the error message and the second is empty.
    """
    if not kb or not llm:
        return "⚠️ Please initialize the system first", ""
    try:
        # A float value (e.g. from a UI slider) cannot be used as a slice bound.
        if isinstance(top_k, float):
            top_k = int(top_k)

        # Domain analysis
        source_domains = kb.analyze_domains(query)
        target_domains = kb.get_cross_domain_targets(source_domains)

        # Hybrid search: source + cross-domain
        candidates = _collect_candidates(query, top_k, kb, source_domains, target_domains)

        # Rank and deduplicate each origin separately
        same_hits = _dedupe_by_title([p for p in candidates if _is_same_domain(p)])
        cross_hits = _dedupe_by_title([p for p in candidates if _is_cross_domain(p)])

        # Normally 4 + 4; with fewer than two cross-domain hits the same-domain
        # share grows to 6.
        if len(cross_hits) < 2:
            patents = same_hits[:6] + cross_hits[:2]
        else:
            patents = same_hits[:4] + cross_hits[:4]

        # Generate concept
        concept = llm.generate_design_concept(
            query=query,
            patents=patents,
            source_domains=source_domains,
            target_domains=target_domains,
            mode="multi",
        )

        # Format output
        config_info = _format_config_info(patents, query)
        domain_info = f"📊 Domains: {source_domains}\n🎯 Cross-domain Targets: {target_domains}"
        patents_info = "🔍 Cross-Domain Retrieval Results:\n" + _format_patent_lines(patents)
        concept_text = concept.get("design_text", "")
        concept_info = f"💡 Design Concept ({concept.get('mode_name', 'Multi')}):\n\n{concept_text}"
        return domain_info + "\n" + config_info + "\n" + patents_info, concept_info
    except Exception as e:
        return f"❌ Error: {str(e)}", ""
def run_pipeline(query: str, mode: str, top_k: int, data_choice: str, api_key: str):
    """Main pipeline: validate the query, initialise the system, run a mode.

    Args:
        query: Design problem in natural language.
        mode: Generation mode label from the UI; the single-mode label selects
            ``run_single_mode``, anything else selects ``run_multi_mode``.
        top_k: Number of patents to retrieve.
        data_choice: Dataset label forwarded to ``init_system``.
        api_key: SiliconFlow API key forwarded to ``init_system``.

    Returns:
        ``(retrieval_text, concept_text)`` for the two output boxes. When the
        query is blank or initialisation fails, the first element carries the
        message and the second is empty.
    """
    # Reject a blank query first: initialisation loads the dataset and can be
    # slow, so it should not run for a request that cannot produce a result.
    if not (query or "").strip():
        return "⚠️ Please enter a design problem", ""

    init_msg, kb, llm = init_system(data_choice, api_key)
    if kb is None or llm is None:
        return init_msg, ""

    if mode == "Single Mode (Single-Patent Deepening)":
        return run_single_mode(query, top_k, kb, llm)
    return run_multi_mode(query, top_k, kb, llm)
# ---- Gradio UI ----
with gr.Blocks(title="PatentConceptor", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🔬 PatentConceptor
    ### Patent-Based Cross-Domain Design Concept Generation System
    Enter a design problem in natural language, and the system retrieves cross-domain
    patent knowledge to generate innovative product concepts.
    """)
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Configuration")
            # The server-side key is deliberately NOT used as the initial
            # value: component values are sent to the browser, so prefilling
            # would expose the secret to every visitor. A blank field makes
            # init_system fall back to the environment variable instead.
            api_key_input = gr.Textbox(
                label="SiliconFlow API Key",
                placeholder=(
                    "Server-side key configured - leave blank to use it"
                    if _get_env_api_key()
                    else "sk-... (or set SILICONFLOW_API_KEY env var)"
                ),
                type="password",
                value="",
            )
            # The default must be one of the offered choices; take the first
            # registered dataset rather than a hard-coded label.
            data_select = gr.Dropdown(
                label="Patent Dataset",
                choices=list(DATA_OPTIONS.keys()),
                value=next(iter(DATA_OPTIONS)),
            )
            mode_select = gr.Radio(
                label="Generation Mode",
                choices=[
                    "Single Mode (Single-Patent Deepening)",
                    "Multi Mode (Cross-Domain Fusion)",
                ],
                value="Single Mode (Single-Patent Deepening)",
            )
            top_k_slider = gr.Slider(
                label="Top-K Patents",
                minimum=5,
                maximum=20,
                value=10,
                step=1,
            )
            query_input = gr.Textbox(
                label="Design Problem",
                placeholder="e.g., Wearable heart rate monitoring device",
                lines=2,
            )
            run_btn = gr.Button("🚀 Generate Concept", variant="primary")
        with gr.Column(scale=2):
            gr.Markdown("### 📊 Results")
            retrieval_output = gr.Textbox(
                label="Patent Retrieval & Domain Analysis",
                lines=20,
                interactive=False,
            )
            concept_output = gr.Textbox(
                label="Generated Design Concept",
                lines=15,
                interactive=False,
            )
    gr.Markdown("""
    ---
    **Paper**: *PatentConceptor: Generating innovative product concepts with patent knowledge and large language models*
    **Dataset**: [DesignProblemDataset](https://huggingface.co/datasets/LiuHongwei1992/DesignProblemDatasetForPatentConceptorEvaluation) | **Patents**: [520634PatentsDataset](https://huggingface.co/datasets/LiuHongwei1992/520634PatentsDataset)
    """)
    # Input order must match the parameter order of run_pipeline.
    run_btn.click(
        fn=run_pipeline,
        inputs=[query_input, mode_select, top_k_slider, data_select, api_key_input],
        outputs=[retrieval_output, concept_output],
    )
if __name__ == "__main__":
    # Enable the Gradio queue so that long-running tasks are supported
    # (works around the 120 s HTTP timeout), then start the server.
    demo.queue(default_concurrency_limit=1).launch()