Spaces:
Sleeping
Sleeping
| """ | |
| PatentConceptor - HuggingFace Spaces Gradio App | |
| Patent-Based Cross-Domain Design Concept Generation System | |
| """ | |
| import os | |
| import gradio as gr | |
| from pathlib import Path | |
# ---- Configuration ----
# The online demo ships only the 10K sample so that it loads quickly.
# For local deployment with 100K/500K, download the CSV files from the project
# repository, place them in the data/ directory and register them in
# DATA_OPTIONS below.
DATA_OPTIONS = {
    "10K (Quick Test, ~9MB)": "patents_sample_10k.csv",
}
# Fallback file used when a requested label is not a key of DATA_OPTIONS.
# It is derived from DATA_OPTIONS (first entry) so it can never point at a
# dataset that is not actually offered; the previous hard-coded 100K file name
# is not one of the options and made every fallback end in "Data file not found".
DEFAULT_DATA = next(iter(DATA_OPTIONS.values()))
DATA_DIR = Path("data")
# Global cache: {data_path: kb_instance}
_kb_cache = {}
# Shared LLM client, created by init_system().
_llm = None
# File name of the dataset selected by the most recent init_system() call.
_current_data_file = None
def _get_env_api_key():
    """Return SILICONFLOW_API_KEY from the environment, stripped ("" if unset)."""
    raw_key = os.environ.get("SILICONFLOW_API_KEY", "")
    return raw_key.strip()
# API key the current ``_llm`` client was created with (None until first init).
_llm_api_key = None
# {data_path: hybrid engine}; kept so a cached knowledge base can be re-pointed
# at a newly created LLM client without rebuilding its index.
_hybrid_engines = {}


def init_system(data_choice: str, api_key: str):
    """Initialise the knowledge base and the LLM client (cached per dataset).

    Args:
        data_choice: A key of ``DATA_OPTIONS``; unknown labels fall back to
            ``DEFAULT_DATA``.
        api_key: SiliconFlow API key typed by the user. When empty or ``None``
            the ``SILICONFLOW_API_KEY`` environment variable is used instead.

    Returns:
        A ``(message, kb, llm)`` tuple. On any failure ``kb`` and ``llm`` are
        both ``None`` and ``message`` describes the problem.
    """
    global _kb_cache, _llm, _current_data_file, _llm_api_key
    import patent_kb_v2 as kb_mod
    from llm_client import SiliconFlowClient
    from patent_kb_v2 import PatentKnowledgeBaseV2

    data_file = DATA_OPTIONS.get(data_choice, DEFAULT_DATA)
    data_path = DATA_DIR / data_file
    if not data_path.exists():
        return (
            f"❌ Data file not found: {data_path}\n"
            f"Please copy patent CSV files to the data/ directory.",
            None, None,
        )

    # ``api_key`` may be None (e.g. programmatic callers), so guard the strip.
    key = (api_key or "").strip() or _get_env_api_key()
    if not key:
        return "⚠️ Please provide a SiliconFlow API Key", None, None
    # NOTE(review): this writes the user-supplied key into the process-wide
    # environment, so later requests with an empty key field will pick it up
    # through _get_env_api_key(). Kept because other modules may read the
    # variable - confirm whether that is intended on a shared deployment.
    os.environ["SILICONFLOW_API_KEY"] = key

    cache_key = str(data_path)
    try:
        # Reuse an already loaded dataset when possible.
        if cache_key in _kb_cache and _llm is not None:
            kb = _kb_cache[cache_key]
            if key != _llm_api_key:
                # The key changed since the client was built: without a new
                # client every later request would keep using the old key.
                _llm = SiliconFlowClient(api_key=key)
                _llm_api_key = key
            # Repeat the wiring done at first load so this dataset's search
            # stack talks to the current client.
            kb_mod._llm_client = _llm
            cached_engine = _hybrid_engines.get(cache_key)
            if cached_engine is not None:
                cached_engine.set_llm_client(_llm)
            _current_data_file = data_file
            return (
                f"✅ Using cached dataset\n"
                f" Dataset: {data_file}\n"
                f" Patents: {len(kb.patents)}\n"
                f" API Key: configured",
                kb,
                _llm,
            )

        _llm = SiliconFlowClient(api_key=key)
        _llm_api_key = key
        kb_mod._llm_client = _llm
        kb = PatentKnowledgeBaseV2(str(data_path))
        kb.load()

        # Build the LSA index and the hybrid search engine on top of it.
        from semantic_search_sklearn import LSASemanticSearch, HybridSearchEngineLSA
        semantic_engine = LSASemanticSearch(n_components=100)
        semantic_engine.build_index(kb.patents)
        hybrid_engine = HybridSearchEngineLSA(kb, semantic_engine, alpha=0.6)
        hybrid_engine.set_llm_client(_llm)
        kb.set_hybrid_engine(hybrid_engine)

        # Only publish to the caches once everything above succeeded.
        _kb_cache[cache_key] = kb
        _hybrid_engines[cache_key] = hybrid_engine
        _current_data_file = data_file
        info = (
            f"✅ System initialized (first load, may take 1-3 min for 500K)\n"
            f" Dataset: {data_file}\n"
            f" Patents: {len(kb.patents)}\n"
            f" Hybrid Search: enabled\n"
            f" API Key: configured"
        )
        return info, kb, _llm
    except Exception as e:
        # UI boundary: report the failure as a message instead of raising.
        return f"❌ Initialization failed: {str(e)}", None, None
def _search_sections(kb, query, top_k, sections, tag, score_weight=None, error_label=""):
    """Hybrid-search each section and tag every hit as ``<tag>_<section>``.

    When ``score_weight`` is given, each hit's score is multiplied by it.
    A section whose search raises is logged and skipped (best effort).
    """
    hits = []
    for section in sections:
        try:
            results = kb.search_hybrid(query, section, top_k=top_k)
            for r in results:
                r["domain_type"] = f"{tag}_{section}"
                if score_weight is not None:
                    r["score"] = r.get("score", 0) * score_weight
            hits.extend(results)
        except Exception as e:
            print(f"Error searching {error_label}{section}: {e}")
    return hits


def _retrieve_candidates(kb, query, top_k, source_domains, target_domains):
    """Collect hits from all source sections plus the first three target sections."""
    same_hits = _search_sections(kb, query, top_k, source_domains, "same")
    # Cross-domain hits are down-weighted by 0.9 relative to same-domain hits.
    cross_hits = _search_sections(
        kb, query, top_k, target_domains[:3], "cross",
        score_weight=0.9, error_label="cross ",
    )
    return same_hits + cross_hits


def _dedupe_by_title(patents):
    """Return patents sorted by descending score, keeping the first hit per title.

    Entries without a (truthy) title are dropped.
    """
    seen_titles = set()
    unique = []
    for p in sorted(patents, key=lambda x: x.get("score", 0), reverse=True):
        title = p.get("title", "")
        if title and title not in seen_titles:
            seen_titles.add(title)
            unique.append(p)
    return unique


def _has_domain_prefix(patent, prefix):
    """Tell whether the patent's ``domain_type`` tag starts with ``prefix``."""
    return patent.get("domain_type", "").startswith(prefix)


def _format_query_config(patents, query):
    """Describe the query configuration recorded on the first patent ("" if none)."""
    if not patents:
        return ""
    first = patents[0]
    parts = [f"📝 Query Type: {first.get('query_type', 'N/A')}\n"]
    alpha = first.get('alpha_used', None)
    if alpha is not None:
        parts.append(
            f"⚖️ Hybrid Weight: TF-IDF {round(alpha*100)}% + LSA {round((1-alpha)*100)}%\n"
        )
    parts.append(f"🔤 Functional Keywords: {first.get('search_query', query)}\n")
    return "".join(parts)


def _format_patent_lines(patents):
    """Render one numbered line per patent: origin, IPC, truncated title, score."""
    lines = []
    for i, r in enumerate(patents, 1):
        title = r.get("title", "N/A")[:55]
        ipc = r.get("ipc", "N/A")
        score = r.get("score", 0)
        dtype = "same" if _has_domain_prefix(r, "same_") else "cross"
        lines.append(f" {i}. [{dtype}] [{ipc}] {title} (score={score:.3f})\n")
    return "".join(lines)


def _format_report(query, source_domains, target_domains, patents, header):
    """Assemble the retrieval panel text: domains, query config, patent list."""
    domain_info = f"📊 Domains: {source_domains}\n🎯 Cross-domain Targets: {target_domains}"
    config_info = _format_query_config(patents, query)
    patents_info = header + _format_patent_lines(patents)
    return domain_info + "\n" + config_info + "\n" + patents_info


def _format_concept(concept, default_mode_name):
    """Render the concept panel text from the LLM result dict."""
    concept_text = concept.get("design_text", "")
    return f"💡 Design Concept ({concept.get('mode_name', default_mode_name)}):\n\n{concept_text}"


def run_single_mode(query: str, top_k: int, kb, llm):
    """Single Mode: deepen one high-relevance patent (retrieval shows all domains).

    Args:
        query: Design problem in natural language.
        top_k: Hits requested per section and size of the displayed list.
            Coerced to ``int`` because UI sliders may deliver a float.
        kb: Knowledge base providing ``analyze_domains``,
            ``get_cross_domain_targets`` and ``search_hybrid``.
        llm: Client providing ``generate_design_concept``.

    Returns:
        ``(retrieval_text, concept_text)``; on failure the first element is an
        error message and the second is ``""``.
    """
    if not kb or not llm:
        return "⚠️ Please initialize the system first", ""
    try:
        top_k = int(top_k)
        source_domains = kb.analyze_domains(query)
        target_domains = kb.get_cross_domain_targets(source_domains)
        candidates = _retrieve_candidates(kb, query, top_k, source_domains, target_domains)
        patents = _dedupe_by_title(candidates)[:top_k]

        # Single-mode semantics: only the best-ranked patent feeds the LLM.
        concept = llm.generate_design_concept(
            query=query,
            patents=patents[:1],
            source_domains=source_domains,
            target_domains=target_domains,
            mode="single",
        )

        same_count = sum(1 for p in patents if _has_domain_prefix(p, "same_"))
        cross_count = sum(1 for p in patents if _has_domain_prefix(p, "cross_"))
        header = f"🔍 Retrieved Patents (same: {same_count}, cross: {cross_count}):\n"
        report = _format_report(query, source_domains, target_domains, patents, header)
        return report, _format_concept(concept, "Single")
    except Exception as e:
        return f"❌ Error: {str(e)}", ""


def run_multi_mode(query: str, top_k: int, kb, llm):
    """Multi Mode: fuse cross-domain patents.

    Same-domain and cross-domain hits are deduplicated separately; up to four
    of each are passed to the LLM, or six same-domain plus up to two
    cross-domain hits when fewer than two cross-domain hits exist.

    Args:
        query: Design problem in natural language.
        top_k: Hits requested per section (coerced to ``int``).
        kb: Knowledge base providing ``analyze_domains``,
            ``get_cross_domain_targets`` and ``search_hybrid``.
        llm: Client providing ``generate_design_concept``.

    Returns:
        ``(retrieval_text, concept_text)``; on failure the first element is an
        error message and the second is ``""``.
    """
    if not kb or not llm:
        return "⚠️ Please initialize the system first", ""
    try:
        top_k = int(top_k)
        source_domains = kb.analyze_domains(query)
        target_domains = kb.get_cross_domain_targets(source_domains)
        candidates = _retrieve_candidates(kb, query, top_k, source_domains, target_domains)

        unique_same = _dedupe_by_title(
            [p for p in candidates if _has_domain_prefix(p, "same_")]
        )
        unique_cross = _dedupe_by_title(
            [p for p in candidates if _has_domain_prefix(p, "cross_")]
        )
        if len(unique_cross) < 2:
            patents = unique_same[:6] + unique_cross[:2]
        else:
            patents = unique_same[:4] + unique_cross[:4]

        concept = llm.generate_design_concept(
            query=query,
            patents=patents,
            source_domains=source_domains,
            target_domains=target_domains,
            mode="multi",
        )

        report = _format_report(
            query, source_domains, target_domains, patents,
            "🔍 Cross-Domain Retrieval Results:\n",
        )
        return report, _format_concept(concept, "Multi")
    except Exception as e:
        return f"❌ Error: {str(e)}", ""
def run_pipeline(query: str, mode: str, top_k: int, data_choice: str, api_key: str):
    """Main pipeline: validate the query, initialise the system, run the mode.

    Args:
        query: Design problem entered by the user.
        mode: Label of the selected generation mode; anything other than the
            single-mode label runs multi mode.
        top_k: Number of patents to retrieve per section.
        data_choice: Dataset label forwarded to ``init_system``.
        api_key: API key forwarded to ``init_system``.

    Returns:
        ``(retrieval_text, concept_text)`` for the two output boxes.
    """
    # Reject an empty query before init_system(): initialisation loads the
    # dataset and builds the search index, which is wasted work (and a long
    # wait) when there is nothing to search for.
    if not (query or "").strip():
        return "⚠️ Please enter a design problem", ""

    init_msg, kb, llm = init_system(data_choice, api_key)
    if kb is None or llm is None:
        return init_msg, ""

    if mode == "Single Mode (Single-Patent Deepening)":
        return run_single_mode(query, top_k, kb, llm)
    return run_multi_mode(query, top_k, kb, llm)
# ---- Gradio UI ----
# Layout: a configuration column (API key, dataset, mode, top-k, query, run
# button) next to a results column (retrieval report and generated concept).
with gr.Blocks(title="PatentConceptor", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🔬 PatentConceptor
    ### Patent-Based Cross-Domain Design Concept Generation System
    Enter a design problem in natural language, and the system retrieves cross-domain
    patent knowledge to generate innovative product concepts.
    """)
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Configuration")
            api_key_input = gr.Textbox(
                label="SiliconFlow API Key",
                placeholder="sk-... (or set SILICONFLOW_API_KEY env var)",
                type="password",
                value=_get_env_api_key(),
            )
            data_select = gr.Dropdown(
                label="Patent Dataset",
                choices=list(DATA_OPTIONS.keys()),
                # The default must be one of the offered choices; the previous
                # hard-coded "100K (Balanced, ~88MB)" label is not a key of
                # DATA_OPTIONS, so the first offered dataset is used instead.
                value=next(iter(DATA_OPTIONS)),
            )
            mode_select = gr.Radio(
                label="Generation Mode",
                choices=[
                    "Single Mode (Single-Patent Deepening)",
                    "Multi Mode (Cross-Domain Fusion)",
                ],
                value="Single Mode (Single-Patent Deepening)",
            )
            top_k_slider = gr.Slider(
                label="Top-K Patents",
                minimum=5,
                maximum=20,
                value=10,
                step=1,
            )
            query_input = gr.Textbox(
                label="Design Problem",
                placeholder="e.g., Wearable heart rate monitoring device",
                lines=2,
            )
            run_btn = gr.Button("🚀 Generate Concept", variant="primary")
        with gr.Column(scale=2):
            gr.Markdown("### 📊 Results")
            retrieval_output = gr.Textbox(
                label="Patent Retrieval & Domain Analysis",
                lines=20,
                interactive=False,
            )
            concept_output = gr.Textbox(
                label="Generated Design Concept",
                lines=15,
                interactive=False,
            )
    gr.Markdown("""
    ---
    **Paper**: *PatentConceptor: Generating innovative product concepts with patent knowledge and large language models*
    **Dataset**: [DesignProblemDataset](https://huggingface.co/datasets/LiuHongwei1992/DesignProblemDatasetForPatentConceptorEvaluation) | **Patents**: [520634PatentsDataset](https://huggingface.co/datasets/LiuHongwei1992/520634PatentsDataset)
    """)
    # The button feeds all five inputs to run_pipeline and fills both result boxes.
    run_btn.click(
        fn=run_pipeline,
        inputs=[query_input, mode_select, top_k_slider, data_select, api_key_input],
        outputs=[retrieval_output, concept_output],
    )
if __name__ == "__main__":
    # Enable the Gradio queue so long-running tasks are supported
    # (works around the HTTP 120 s timeout), then start the server.
    demo.queue(default_concurrency_limit=1).launch()