Spaces:
Sleeping
Sleeping
| import time | |
| import asyncio | |
| import traceback | |
| from typing import List, Dict, Any, Optional, Callable, Tuple | |
| from langsmith import traceable | |
| try: | |
| import config | |
| from services import retriever, openai_service | |
| except ImportError: | |
| print("Error: Failed to import config or services in rag_processor.py") | |
| raise SystemExit("Failed imports in rag_processor.py") | |
# NOTE(review): the Hebrew user-facing strings in this module appear mis-encoded in
# this copy (UTF-8 Hebrew rendered as CJK characters); confirm against the original
# file before editing them, and keep identical copies in sync because some are
# compared by value (e.g. the "no passages supplied" message returned by the
# generation step and checked again by the pipeline).
# Label recorded under "pipeline_used" in the dict returned by
# execute_validate_generate_pipeline.
PIPELINE_VALIDATE_GENERATE_GPT4O = "GPT-4o Validator + GPT-4o Synthesizer"
# Progress-reporting callback type threaded through every pipeline step: it is
# called with one status-message string and its return value is not used.
StatusCallback = Callable[[str], None]
| # --- Step Functions --- | |
async def run_retrieval_step(query: str, n_retrieve: int, update_status: StatusCallback) -> List[Dict]:
    """Fetch up to ``n_retrieve`` candidate passages for ``query``.

    Progress is reported through ``update_status``: once before the lookup, once
    with the number of passages found and the elapsed wall-clock time, and once
    more when nothing was found.

    Args:
        query: The user question, passed as ``query_text`` to the retriever.
        n_retrieve: Maximum number of passages requested (``n_results``).
        update_status: Callback receiving human-readable progress strings.

    Returns:
        The value returned by ``retriever.retrieve_documents``, unchanged.
    """
    update_status(f"1. 诪讗讞讝专 注讚 {n_retrieve} 驻住拽讗讜转 诪-Pinecone...")

    # NOTE(review): this call is synchronous (not awaited), so it occupies the
    # event loop for its whole duration; consider asyncio.to_thread if the
    # retriever is thread-safe.
    t0 = time.time()
    docs = retriever.retrieve_documents(query_text=query, n_results=n_retrieve)
    elapsed = time.time() - t0

    update_status(f"1. 讗讜讞讝专讜 {len(docs)} 驻住拽讗讜转 讘-{elapsed:.2f} 砖谞讬讜转.")
    if not docs:
        update_status("1. 诇讗 讗讜转专讜 诪住诪讻讬诐.")
    return docs
async def run_gpt4o_validation_filter_step(
    docs_to_process: List[Dict], query: str, n_validate: int, update_status: StatusCallback
) -> List[Dict]:
    """Validate the leading passages concurrently and keep only the relevant ones.

    The first ``n_validate`` documents are each sent to
    ``openai_service.validate_relevance_openai`` and the calls run concurrently via
    ``asyncio.gather``. A document is kept when its result is a dict whose
    ``'validation'`` entry is a dict with a truthy ``'contains_relevant_info'``; that
    validation payload is then stored on the document under ``'validation_result'``
    (the input dict is mutated in place). Documents past ``n_validate`` are never
    validated and are not returned.

    Args:
        docs_to_process: Candidate passages, in retrieval order.
        query: The user question each passage is judged against.
        n_validate: How many leading passages to validate; negative values are
            treated as zero.
        update_status: Callback receiving human-readable progress strings.

    Returns:
        The passages that passed validation, in their original order.
    """
    if not docs_to_process:
        update_status("2. [GPT-4o] 讚讬诇讜讙 注诇 讗讬诪讜转 - 讗讬谉 驻住拽讗讜转.")
        return []

    # Clamp to [0, len]: a negative n_validate would otherwise turn the slice below
    # into "all but the last |n| docs" and validate far more than requested.
    validation_count = max(0, min(len(docs_to_process), n_validate))
    update_status(f"2. [GPT-4o] 诪转讞讬诇 讗讬诪讜转 诪拽讘讬诇讬 ({validation_count} / {len(docs_to_process)} 驻住拽讗讜转)...")
    validation_start_time = time.time()
    docs_to_validate = docs_to_process[:validation_count]
    tasks = [openai_service.validate_relevance_openai(doc, query, i)
             for i, doc in enumerate(docs_to_validate)]
    # return_exceptions=True: a failing call is returned as a value instead of
    # propagating out of gather, so the remaining results are still collected.
    validation_results = await asyncio.gather(*tasks, return_exceptions=True)

    passed_docs: List[Dict] = []
    passed_count = failed_validation_count = error_count = 0
    update_status("3. [GPT-4o] 住讬谞讜谉 驻住拽讗讜转 诇驻讬 转讜爪讗讜转 讗讬诪讜转...")
    # gather preserves input order, so results pair 1:1 with docs_to_validate.
    for i, (original_doc, res) in enumerate(zip(docs_to_validate, validation_results)):
        if isinstance(res, BaseException):
            print(f"GPT-4o Validation Exception doc {i}: {res}")
            error_count += 1
            continue
        if not isinstance(res, dict) or 'validation' not in res:
            print(f"GPT-4o Validation Unexpected result doc {i}: {type(res)}")
            error_count += 1
            continue
        validation = res['validation']
        if not isinstance(validation, dict):
            # Malformed payload (e.g. None): count it as an error rather than letting
            # .get() raise and abort the whole pipeline.
            print(f"GPT-4o Validation malformed 'validation' payload doc {i}: {type(validation)}")
            error_count += 1
        elif validation.get('contains_relevant_info'):
            original_doc['validation_result'] = validation
            passed_docs.append(original_doc)
            passed_count += 1
        else:
            failed_validation_count += 1

    validation_time = time.time() - validation_start_time
    status_msg_val = (f"讗讬诪讜转 GPT-4o 讛讜砖诇诐 ({passed_count} 注讘专讜, "
                      f"{failed_validation_count} 谞讚讞讜, {error_count} 砖讙讬讗讜转) "
                      f"讘-{validation_time:.2f} 砖谞讬讜转.")
    update_status(f"2. {status_msg_val}")
    status_msg_filter = f"谞讗住驻讜 {len(passed_docs)} 驻住拽讗讜转 专诇讜讜谞讟讬讜转 诇讗讞专 讗讬诪讜转 GPT-4o."
    update_status(f"3. {status_msg_filter}")
    return passed_docs
async def run_openai_generation_step(
    history: List[Dict], context_documents: List[Dict],
    update_status: StatusCallback, stream_callback: Callable[[str], None]
) -> Tuple[str, Optional[str]]:
    """Stream the final answer from ``openai_service.generate_openai_stream``.

    Each text chunk is forwarded to ``stream_callback`` as it arrives and also
    accumulated. A chunk whose stripped text starts with ``"--- Error:"`` is treated
    as an in-band error marker: it is not forwarded, streaming stops, and the marker
    becomes the returned error. Non-string items from the stream are ignored.

    Args:
        history: Chat messages, passed to the service as ``messages``.
        context_documents: Passages passed to the service as ``context_documents``.
        update_status: Callback receiving human-readable progress strings.
        stream_callback: Callback receiving each answer text chunk.

    Returns:
        ``(response_text, error)``:
        * no context documents -> a fixed "no passages supplied" message and
          ``None``, without calling the service;
        * success -> the concatenated text and ``None``;
        * in-band error marker -> the text received so far and the marker;
        * an exception escaping the stream or a callback ->
          ``("", "--- Error: Critical failure during OpenAI generation ...")``.
    """
    generator_name = "OpenAI"
    if not context_documents:
        update_status(f"4. [{generator_name}] 讚讬诇讜讙 注诇 讬爪讬专讛 - 讗讬谉 驻住拽讗讜转 诇讛拽砖专.")
        return "诇讗 住讜驻拽讜 驻住拽讗讜转 专诇讜讜谞讟讬讜转 诇讬爪讬专转 讛转砖讜讘讛.", None

    update_status(f"4. [{generator_name}] 诪讞讜诇诇 转砖讜讘讛 住讜驻讬转 诪-{len(context_documents)} 拽讟注讬 讛拽砖专...")
    start_gen_time = time.time()
    generator = None
    try:
        full_response: List[str] = []
        error_msg: Optional[str] = None
        generator = openai_service.generate_openai_stream(
            messages=history, context_documents=context_documents
        )
        async for chunk in generator:
            if not isinstance(chunk, str):
                continue  # non-text items are ignored
            stripped = chunk.strip()
            if stripped.startswith("--- Error:"):
                # In-band error marker from the service: record it, don't forward it.
                error_msg = stripped
                print(f"OpenAI stream yielded error: {stripped}")
                break
            full_response.append(chunk)
            stream_callback(chunk)

        final_response_text = "".join(full_response)
        gen_time = time.time() - start_gen_time
        if error_msg:
            update_status(f"4. 砖讙讬讗讛 讘讬爪讬专转 讛转砖讜讘讛 ({generator_name}) 讘-{gen_time:.2f} 砖谞讬讜转.")
            return final_response_text, error_msg
        update_status(f"4. 讬爪讬专转 讛转砖讜讘讛 ({generator_name}) 讛讜砖诇诪讛 讘-{gen_time:.2f} 砖谞讬讜转.")
        return final_response_text, None
    except Exception as gen_err:
        gen_time = time.time() - start_gen_time
        error_msg_critical = (f"--- Error: Critical failure during {generator_name} generation "
                              f"({type(gen_err).__name__}): {gen_err} ---")
        update_status(f"4. 砖讙讬讗讛 拽专讬讟讬转 讘讬爪讬专转 讛转砖讜讘讛 ({generator_name}) 讘-{gen_time:.2f} 砖谞讬讜转.")
        traceback.print_exc()
        return "", error_msg_critical
    finally:
        # Leaving `async for` early (error marker or exception) does not finalize an
        # async generator; close it explicitly so any cleanup it defines runs now
        # instead of whenever it is garbage-collected. Close failures are only
        # logged so they cannot replace the value being returned.
        aclose = getattr(generator, "aclose", None)
        if aclose is not None:
            try:
                await aclose()
            except Exception as close_err:
                print(f"OpenAI stream close failed: {close_err}")
def _extract_latest_user_query(history: Any) -> str:
    """Return the content of the most recent ``role == "user"`` message, or ``""``.

    Only the newest user message is considered: if its content is empty, ``""`` is
    returned even when an older user message has content.
    """
    if not history or not isinstance(history, list):
        return ""
    for message in reversed(history):
        if isinstance(message, dict) and message.get("role") == "user":
            return str(message.get("content") or "")
    return ""


def _simplify_docs_for_generation(validated_docs: List[Dict]) -> List[Dict[str, Any]]:
    """Project validated docs onto the fields passed to the generator.

    Non-dict items are skipped with a warning; dicts without a truthy
    ``'hebrew_text'`` are skipped silently. Each kept doc carries ``'hebrew_text'``,
    ``'original_id'`` (default ``'unknown'``), ``'source_name'`` when truthy, and
    ``'validation_result'`` when not None.
    """
    simplified: List[Dict[str, Any]] = []
    print(f"Processor: Simplifying {len(validated_docs)} docs...")
    for doc in validated_docs:
        if not isinstance(doc, dict):
            print(f"Warn: Skipping non-dict item: {doc}")
            continue
        hebrew_text = doc.get('hebrew_text', '')
        if not hebrew_text:
            continue
        simplified_doc: Dict[str, Any] = {
            'hebrew_text': hebrew_text,
            'original_id': doc.get('original_id', 'unknown'),
        }
        if doc.get('source_name'):
            simplified_doc['source_name'] = doc.get('source_name')
        validation = doc.get('validation_result')
        if validation is not None:
            simplified_doc['validation_result'] = validation  # include judgment
        simplified.append(simplified_doc)
    print(f"Processor: Created {len(simplified)} simplified docs with validation results.")
    return simplified


def _rtl_div(message: str) -> str:
    """Wrap ``message`` in the right-to-left container div used for user-facing output."""
    return f"<div class='rtl-text'>{message}</div>"


async def execute_validate_generate_pipeline(
    history: List[Dict], params: Dict[str, Any],
    status_callback: StatusCallback, stream_callback: Callable[[str], None]
) -> Dict[str, Any]:
    """Run retrieval -> GPT-4o validation -> OpenAI generation for the latest user turn.

    The query is the content of the most recent ``role == "user"`` message in
    ``history``. Every progress message is printed, appended to
    ``result["status_log"]`` and forwarded to ``status_callback``; answer text is
    streamed through ``stream_callback`` by the generation step.

    Args:
        history: Chat messages (dicts with ``role``/``content``); also passed to
            the generator.
        params: Must contain ``'n_retrieve'`` and ``'n_validate'``; a missing key
            surfaces as a critical error in the result rather than raising.
        status_callback: Receives each status message.
        stream_callback: Receives each generated answer chunk.

    Returns:
        A dict with keys ``final_response`` (HTML-wrapped on error paths),
        ``validated_documents_full``, ``generator_input_documents``,
        ``status_log``, ``error`` (``None`` on success) and ``pipeline_used``.
    """
    import html  # local import: only needed to escape diagnostics embedded in HTML

    status_log_internal: List[str] = []
    result: Dict[str, Any] = {
        "final_response": "",
        "validated_documents_full": [],
        "generator_input_documents": [],
        # Bound to the live log list so every return path, including each early
        # exit, reports the complete status history.
        "status_log": status_log_internal,
        "error": None,
        "pipeline_used": PIPELINE_VALIDATE_GENERATE_GPT4O,
    }

    def update_status_and_log(message: str) -> None:
        print(f"Status Update: {message}")
        status_log_internal.append(message)
        status_callback(message)

    current_query_text = _extract_latest_user_query(history)
    if not current_query_text:
        result["error"] = "诇讗 讝讜讛转讛 砖讗诇讛."
        result["final_response"] = _rtl_div(result["error"])
        return result

    try:
        # 1. Retrieval
        retrieved_docs = await run_retrieval_step(
            current_query_text, params['n_retrieve'], update_status_and_log
        )
        if not retrieved_docs:
            result["error"] = "诇讗 讗讜转专讜 诪拽讜专讜转."
            result["final_response"] = _rtl_div(result["error"])
            return result

        # 2. Validation
        validated_docs_full = await run_gpt4o_validation_filter_step(
            retrieved_docs, current_query_text, params['n_validate'], update_status_and_log
        )
        result["validated_documents_full"] = validated_docs_full
        if not validated_docs_full:
            result["error"] = "诇讗 谞诪爪讗讜 驻住拽讗讜转 专诇讜讜谞讟讬讜转."
            result["final_response"] = _rtl_div(result["error"])
            update_status_and_log(f"4. {result['error']} 诇讗 谞讬转谉 诇讛诪砖讬讱.")
            return result

        # Trim validated docs to the fields the generator uses.
        simplified_docs = _simplify_docs_for_generation(validated_docs_full)
        result["generator_input_documents"] = simplified_docs

        # 3. Generation
        final_response_text, generation_error = await run_openai_generation_step(
            history=history,
            context_documents=simplified_docs,
            update_status=update_status_and_log,
            stream_callback=stream_callback,
        )
        result["final_response"] = final_response_text
        result["error"] = generation_error
        if generation_error and not final_response_text.strip().startswith(("<div", "诇讗 住讜驻拽讜")):
            # The error text is plain diagnostics; escape it so '<'/'>' in it render
            # literally. The partial answer is left as produced by the generator.
            result["final_response"] = (
                f"<div class='rtl-text'><strong>砖讙讬讗讛 讘讬爪讬专转 讛转砖讜讘讛.</strong><br>"
                f"驻专讟讬诐: {html.escape(str(generation_error))}<br>---<br>{final_response_text}</div>"
            )
        elif final_response_text == "诇讗 住讜驻拽讜 驻住拽讗讜转 专诇讜讜谞讟讬讜转 诇讬爪讬专转 讛转砖讜讘讛.":
            result["final_response"] = _rtl_div(final_response_text)
    except Exception as e:
        error_type = type(e).__name__
        error_msg = f"砖讙讬讗讛 拽专讬讟讬转 RAG ({error_type}): {e}"
        print(f"Critical RAG Error: {error_msg}")
        traceback.print_exc()
        result["error"] = error_msg
        # Tracebacks routinely contain '<module>', '<listcomp>', '<... object at 0x...>'
        # and may echo user text; escape so they display verbatim inside <pre>.
        result["final_response"] = (
            f"<div class='rtl-text'><strong>砖讙讬讗讛 拽专讬讟讬转! ({error_type})</strong><br>谞住讛 砖讜讘."
            f"<details><summary>驻专讟讬诐</summary><pre>{html.escape(traceback.format_exc())}</pre></details></div>"
        )
        update_status_and_log(f"砖讙讬讗讛 拽专讬讟讬转: {error_type}")
    return result