""" PatentConceptor - HuggingFace Spaces Gradio App Patent-Based Cross-Domain Design Concept Generation System """ import os import gradio as gr from pathlib import Path # ---- 配置 ---- # Online demo uses 10K for fast loading. # For local deployment with 100K/500K, download from the project repository # and place the CSV files in the data/ directory. DATA_OPTIONS = { "10K (Quick Test, ~9MB)": "patents_sample_10k.csv", } DEFAULT_DATA = "patents_sample_100k.csv" DATA_DIR = Path("data") # 全局缓存: {data_path: kb_instance} _kb_cache = {} _llm = None _current_data_file = None def _get_env_api_key(): return os.getenv("SILICONFLOW_API_KEY", "").strip() def init_system(data_choice: str, api_key: str): """初始化知识库和LLM客户端(带缓存,避免重复加载)""" global _kb_cache, _llm, _current_data_file import patent_kb_v2 as kb_mod from llm_client import SiliconFlowClient from patent_kb_v2 import PatentKnowledgeBaseV2 data_file = DATA_OPTIONS.get(data_choice, DEFAULT_DATA) data_path = DATA_DIR / data_file if not data_path.exists(): return ( f"❌ Data file not found: {data_path}\n" f"Please copy patent CSV files to the data/ directory.", None, None, ) key = api_key.strip() or _get_env_api_key() if not key: return "⚠️ Please provide a SiliconFlow API Key", None, None os.environ["SILICONFLOW_API_KEY"] = key # 检查是否已缓存相同数据集 cache_key = str(data_path) if cache_key in _kb_cache and _llm is not None: _current_data_file = data_file return ( f"✅ Using cached dataset\n" f" Dataset: {data_file}\n" f" Patents: {len(_kb_cache[cache_key].patents)}\n" f" API Key: configured", _kb_cache[cache_key], _llm, ) try: _llm = SiliconFlowClient(api_key=key) kb_mod._llm_client = _llm kb = PatentKnowledgeBaseV2(str(data_path)) kb.load() # 初始化 LSA + Hybrid Search Engine from semantic_search_sklearn import LSASemanticSearch, HybridSearchEngineLSA semantic_engine = LSASemanticSearch(n_components=100) semantic_engine.build_index(kb.patents) hybrid_engine = HybridSearchEngineLSA(kb, semantic_engine, alpha=0.6) hybrid_engine.set_llm_client(_llm) kb.set_hybrid_engine(hybrid_engine) _kb_cache[cache_key] = kb _current_data_file = data_file info = ( f"✅ System initialized (first load, may take 1-3 min for 500K)\n" f" Dataset: {data_file}\n" f" Patents: {len(kb.patents)}\n" f" Hybrid Search: enabled\n" f" API Key: configured" ) return info, kb, _llm except Exception as e: return f"❌ Initialization failed: {str(e)}", None, None def run_single_mode(query: str, top_k: int, kb, llm): """Single Mode: deepen one high-relevance patent (retrieval shows all domains)""" if not kb or not llm: return "⚠️ Please initialize the system first", "" try: # Domain analysis source_domains = kb.analyze_domains(query) target_domains = kb.get_cross_domain_targets(source_domains) # Hybrid search: source domains + cross-domain targets all_patents = [] for section in source_domains: try: results = kb.search_hybrid(query, section, top_k=top_k) for r in results: r["domain_type"] = f"same_{section}" all_patents.extend(results) except Exception as e: print(f"Error searching {section}: {e}") for section in target_domains[:3]: try: results = kb.search_hybrid(query, section, top_k=top_k) for r in results: r["domain_type"] = f"cross_{section}" r["score"] = r.get("score", 0) * 0.9 all_patents.extend(results) except Exception as e: print(f"Error searching cross {section}: {e}") # Deduplicate and sort seen = set() unique = [] for p in sorted(all_patents, key=lambda x: x.get("score", 0), reverse=True): title = p.get("title", "") if title and title not in seen: seen.add(title) unique.append(p) patents = unique[:top_k] # Generate concept using the top-1 patent (single mode semantics) top_patent = [patents[0]] if patents else [] concept = llm.generate_design_concept( query=query, patents=top_patent, source_domains=source_domains, target_domains=target_domains, mode="single", ) # Format output domain_info = f"📊 Domains: {source_domains}\n🎯 Cross-domain Targets: {target_domains}" # 提取查询配置信息(从首个结果中获取) config_info = "" if patents: first = patents[0] qtype = first.get('query_type', 'N/A') alpha = first.get('alpha_used', None) squery = first.get('search_query', query) config_info += f"📝 Query Type: {qtype}\n" if alpha is not None: config_info += f"⚖️ Hybrid Weight: TF-IDF {round(alpha*100)}% + LSA {round((1-alpha)*100)}%\n" config_info += f"🔤 Functional Keywords: {squery}\n" same_count = sum(1 for p in patents if p.get("domain_type", "").startswith("same_")) cross_count = sum(1 for p in patents if p.get("domain_type", "").startswith("cross_")) patents_info = f"🔍 Retrieved Patents (same: {same_count}, cross: {cross_count}):\n" for i, r in enumerate(patents, 1): title = r.get("title", "N/A")[:55] ipc = r.get("ipc", "N/A") score = r.get("score", 0) dtype = "same" if r.get("domain_type", "").startswith("same_") else "cross" patents_info += f" {i}. [{dtype}] [{ipc}] {title} (score={score:.3f})\n" concept_text = concept.get("design_text", "") concept_info = f"💡 Design Concept ({concept.get('mode_name', 'Single')}):\n\n{concept_text}" return domain_info + "\n" + config_info + "\n" + patents_info, concept_info except Exception as e: return f"❌ Error: {str(e)}", "" def run_multi_mode(query: str, top_k: int, kb, llm): """Multi Mode: fuse cross-domain patents""" if not kb or not llm: return "⚠️ Please initialize the system first", "" try: # Domain analysis source_domains = kb.analyze_domains(query) target_domains = kb.get_cross_domain_targets(source_domains) # Hybrid search: source + cross-domain all_patents = [] for section in source_domains: try: results = kb.search_hybrid(query, section, top_k=top_k) for r in results: r["domain_type"] = f"same_{section}" all_patents.extend(results) except Exception as e: print(f"Error searching {section}: {e}") for section in target_domains[:3]: try: results = kb.search_hybrid(query, section, top_k=top_k) for r in results: r["domain_type"] = f"cross_{section}" r["score"] = r.get("score", 0) * 0.9 all_patents.extend(results) except Exception as e: print(f"Error searching cross {section}: {e}") # Separate and deduplicate same = [p for p in all_patents if p.get("domain_type", "").startswith("same_")] cross = [p for p in all_patents if p.get("domain_type", "").startswith("cross_")] seen_same = set() unique_same = [] for p in sorted(same, key=lambda x: x.get("score", 0), reverse=True): title = p.get("title", "") if title and title not in seen_same: seen_same.add(title) unique_same.append(p) seen_cross = set() unique_cross = [] for p in sorted(cross, key=lambda x: x.get("score", 0), reverse=True): title = p.get("title", "") if title and title not in seen_cross: seen_cross.add(title) unique_cross.append(p) patents = unique_same[:4] + unique_cross[:4] if len(unique_cross) < 2: patents = unique_same[:6] + unique_cross[:2] # Generate concept concept = llm.generate_design_concept( query=query, patents=patents, source_domains=source_domains, target_domains=target_domains, mode="multi", ) # 提取查询配置信息 config_info = "" if patents: first = patents[0] qtype = first.get('query_type', 'N/A') alpha = first.get('alpha_used', None) squery = first.get('search_query', query) config_info += f"📝 Query Type: {qtype}\n" if alpha is not None: config_info += f"⚖️ Hybrid Weight: TF-IDF {round(alpha*100)}% + LSA {round((1-alpha)*100)}%\n" config_info += f"🔤 Functional Keywords: {squery}\n" domain_info = f"📊 Domains: {source_domains}\n🎯 Cross-domain Targets: {target_domains}" patents_info = "🔍 Cross-Domain Retrieval Results:\n" for i, r in enumerate(patents, 1): title = r.get("title", "N/A")[:55] ipc = r.get("ipc", "N/A") score = r.get("score", 0) dtype = "same" if r.get("domain_type", "").startswith("same_") else "cross" patents_info += f" {i}. [{dtype}] [{ipc}] {title} (score={score:.3f})\n" concept_text = concept.get("design_text", "") concept_info = f"💡 Design Concept ({concept.get('mode_name', 'Multi')}):\n\n{concept_text}" return domain_info + "\n" + config_info + "\n" + patents_info, concept_info except Exception as e: return f"❌ Error: {str(e)}", "" def run_pipeline(query: str, mode: str, top_k: int, data_choice: str, api_key: str): """Main pipeline: init + run""" init_msg, kb, llm = init_system(data_choice, api_key) if kb is None or llm is None: return init_msg, "" if not query.strip(): return "⚠️ Please enter a design problem", "" if mode == "Single Mode (Single-Patent Deepening)": return run_single_mode(query, top_k, kb, llm) else: return run_multi_mode(query, top_k, kb, llm) # ---- Gradio UI ---- with gr.Blocks(title="PatentConceptor", theme=gr.themes.Soft()) as demo: gr.Markdown(""" # 🔬 PatentConceptor ### Patent-Based Cross-Domain Design Concept Generation System Enter a design problem in natural language, and the system retrieves cross-domain patent knowledge to generate innovative product concepts. """) with gr.Row(): with gr.Column(scale=1): gr.Markdown("### ⚙️ Configuration") api_key_input = gr.Textbox( label="SiliconFlow API Key", placeholder="sk-... (or set SILICONFLOW_API_KEY env var)", type="password", value=_get_env_api_key(), ) data_select = gr.Dropdown( label="Patent Dataset", choices=list(DATA_OPTIONS.keys()), value="100K (Balanced, ~88MB)", ) mode_select = gr.Radio( label="Generation Mode", choices=[ "Single Mode (Single-Patent Deepening)", "Multi Mode (Cross-Domain Fusion)", ], value="Single Mode (Single-Patent Deepening)", ) top_k_slider = gr.Slider( label="Top-K Patents", minimum=5, maximum=20, value=10, step=1, ) query_input = gr.Textbox( label="Design Problem", placeholder="e.g., Wearable heart rate monitoring device", lines=2, ) run_btn = gr.Button("🚀 Generate Concept", variant="primary") with gr.Column(scale=2): gr.Markdown("### 📊 Results") retrieval_output = gr.Textbox( label="Patent Retrieval & Domain Analysis", lines=20, interactive=False, ) concept_output = gr.Textbox( label="Generated Design Concept", lines=15, interactive=False, ) gr.Markdown(""" --- **Paper**: *PatentConceptor: Generating innovative product concepts with patent knowledge and large language models* **Dataset**: [DesignProblemDataset](https://huggingface.co/datasets/LiuHongwei1992/DesignProblemDatasetForPatentConceptorEvaluation) | **Patents**: [520634PatentsDataset](https://huggingface.co/datasets/LiuHongwei1992/520634PatentsDataset) """) run_btn.click( fn=run_pipeline, inputs=[query_input, mode_select, top_k_slider, data_select, api_key_input], outputs=[retrieval_output, concept_output], ) if __name__ == "__main__": # 启用 Gradio Queue,支持长时间运行任务(绕过 HTTP 120s 超时) demo.queue(default_concurrency_limit=1) demo.launch()