Spaces:
Sleeping
Sleeping
import os
import json
import base64
import requests
from pathlib import Path
from typing import List, Dict, Any
from dotenv import load_dotenv
import asyncio

# Load environment variables from a .env file
load_dotenv()

# --- Configuration ---
# Knowledge-base layout: kb/product_info/*.json holds per-machine data,
# kb/processed_images/<machine_id>/*.png holds the pre-processed images.
KB_BASE_PATH = Path("kb")
JSON_INFO_PATH = KB_BASE_PATH / "product_info"
PROCESSED_IMAGES_PATH = KB_BASE_PATH / "processed_images"
# API key for OpenRouter; missing key is handled at call time by the entry points.
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
def _make_llm_call(payload: Dict[str, Any], timeout: int = 60) -> Dict[str, Any]:
    """POST a chat-completion payload to OpenRouter and return the parsed JSON response.

    Raises requests.HTTPError on non-2xx responses (via raise_for_status).
    """
    auth_headers = {"Authorization": f"Bearer {OPENROUTER_API_KEY}"}
    resp = requests.post(
        "https://openrouter.ai/api/v1/chat/completions",
        headers=auth_headers,
        json=payload,
        timeout=timeout,
    )
    resp.raise_for_status()
    return resp.json()
def _get_relevant_machines_from_llm(query: str, machine_map: Dict[str, Any]) -> List[str]:
    """Route the user query to a subset of machine IDs via an LLM classifier.

    On any API or parsing failure, falls back to returning every known
    machine ID so downstream retrieval still has something to work with.
    """
    print("-> [Router] Calling Router LLM to identify relevant machines...")
    system_prompt = f"""
You are an expert router assistant. Your task is to analyze a user's query and a list of available machines.
You must identify which machines are relevant. Respond ONLY with a valid JSON object: {{"machines": ["machine_id_1", "machine_id_2"]}}.
Here are examples of how to behave:
- User Query: "Tell me about the Cantek JDT75" -> {{"machines": ["machine1"]}}
- User Query: "Compare the JDT75 and the DT65" -> {{"machines": ["machine1", "machine4"]}}
- User Query: "Which machines have a 2HP motor?" OR "I need a dovetailer for a small shop" -> {{"machines": {list(machine_map.keys())}}}
- User Query: "Do you sell hammers?" -> {{"machines": []}}
Core Rules:
1. If the user asks about specific machines, return their IDs.
2. If the user asks a general question about features, capabilities, or recommendations, you MUST return a list of ALL available machine IDs to enable a full search.
3. If the query is irrelevant, return an empty list.
"""
    user_message = f"User Query: \"{query}\"\n\nAvailable Machines:\n{json.dumps(machine_map, indent=2)}"
    request_body = {
        "model": "openai/gpt-4o",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ],
        "response_format": {"type": "json_object"},
    }
    try:
        reply = _make_llm_call(request_body)
        parsed = json.loads(reply['choices'][0]['message']['content'])
        selected_ids = parsed.get("machines", [])
        print(f" -> [Router] Identified: {selected_ids}")
        return selected_ids
    except Exception as e:
        # Fail open: a broken router should not block retrieval entirely.
        print(f" ! [Router] Error: {e}. Defaulting to all machines.")
        return list(machine_map.keys())
def _batch_is_json_sufficient(query: str, machine_json_map: Dict[str, Any]) -> Dict[str, str]:
    """Classify, per machine, whether its JSON data alone can answer the query.

    Returns a dict mapping machine_id -> "YES" / "NO". On any failure every
    machine defaults to "NO", which forces the image-analysis path.
    """
    print(" -> [Tier 1] Batch checking if JSON data is sufficient for all machines...")
    system_prompt = (
        "You are a strict, binary classification bot. Your only job is to determine if a user query REQUIRES looking at an image. Follow these two rules in order."
        "\n\n**Rule #1: Visual Query Check (Highest Priority)**"
        "First, scan the user's query for any words or phrases that imply visual inspection. If the query contains words like `show me`, `look like`, `see`, `color`, `image`, `picture`, `appearance`, `button`, `handle`, `panel`, or any other term that asks for a visual description, you MUST answer **NO**."
        "\n\n**Rule #2: Data Query Check (Only if Rule #1 does not apply)**"
        "If the query is NOT visual, your answer is **YES**, as long as the JSON contains the relevant data key. For a price query ('how much?', 'under 10k'), if the JSON has a 'price' key, the answer is **YES**."
        "\n\n**Default Behavior:**"
        "Your default position is that the JSON is sufficient ('YES') unless the query explicitly forces a visual check ('NO'). Do not invent reasons to need images for questions about price or technical data."
        "\n\n**Output Format:**"
        "Respond ONLY with a valid JSON object mapping each machine ID to 'YES' or 'NO'. Example: {\"machine1\": \"YES\", \"machine2\": \"NO\"}"
    )
    user_message = f"Query: \"{query}\"\n\nMachine JSONs:\n{json.dumps(machine_json_map, indent=2)}"
    request_body = {
        "model": "openai/gpt-4o",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ],
        "response_format": {"type": "json_object"},
    }
    try:
        reply = _make_llm_call(request_body, timeout=60)
        decisions = json.loads(reply['choices'][0]['message']['content'])
        print(f" -> [Tier 1] Batch Decision: {decisions}")
        return decisions
    except Exception as e:
        print(f"Batch JSON sufficiency check error: {e}")
        # Conservative fallback: treat every machine as needing image analysis.
        return {mid: "NO" for mid in machine_json_map}
def _get_targeted_images(query: str, image_descriptions: List[Dict]) -> List[str]:
    """Select the image filenames most likely to answer the query.

    Returns an empty list on any failure, which callers treat as
    "analyze everything" rather than an error.
    """
    print(" -> [Tier 2] Identifying targeted images for analysis...")
    system_prompt = "Analyze the user query and the list of image descriptions. Which specific images are most likely to contain the answer? Respond ONLY with a JSON object: {\"images\": [\"filename1.png\", \"filename2.png\"]}."
    user_message = f"Query: \"{query}\"\n\nImage Descriptions:\n{json.dumps(image_descriptions, indent=2)}"
    request_body = {
        "model": "openai/gpt-4o",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ],
        "response_format": {"type": "json_object"},
    }
    try:
        reply = _make_llm_call(request_body)
        chosen = json.loads(reply['choices'][0]['message']['content']).get("images", [])
        print(f" -> [Tier 2] Selected images: {chosen}")
        return chosen
    except Exception as e:
        print(f" ! [Tier 2] Error: {e}. Proceeding to full analysis.")
        return []
def _analyze_images_batch(image_paths: List[Path], query: str) -> str:
    """Run a single multimodal GPT-4o call over all of a machine's images.

    Images are inlined as base64 data URLs. Returns the model's answer text,
    or an error string if the API call fails. Note: file-read errors while
    encoding are NOT caught here and will propagate to the caller.
    """
    print(f" -> [Batch Analysis] Sending {len(image_paths)} images for batch analysis.")
    # First content part is the instruction text; each image follows as its own part.
    content_parts: List[Dict[str, Any]] = [
        {"type": "text", "text": f"Analyze ALL of these images for the following machine to answer: '{query}'. If any image shows the requested feature, specify which image(s) and describe what you see. If none, say so."}
    ]
    for img_path in image_paths:
        encoded = base64.b64encode(img_path.read_bytes()).decode('utf-8')
        content_parts.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded}"}})
    request_body = {
        "model": "openai/gpt-4o",
        "messages": [{"role": "user", "content": content_parts}]
    }
    try:
        # Longer timeout: multi-image payloads are large and slow to process.
        reply = _make_llm_call(request_body, timeout=120)
        return reply['choices'][0]['message']['content']
    except Exception as e:
        return f"API Error during batch analysis: {e}"
def _get_context_for_machine(query: str, machine_id: str, json_data: dict, json_content_str: str, json_sufficient: bool) -> str:
    """Assemble the context block for one machine.

    Fast path returns the JSON text alone when it was judged sufficient;
    otherwise targeted (or, as fallback, all) images are batch-analyzed and
    appended to the JSON data.
    """
    # Fast path: structured data already answers the query.
    if json_sufficient:
        return f"--- CONTEXT FOR {machine_id.upper()} (from JSON only) ---\n\n{json_content_str}"

    machine_image_dir = PROCESSED_IMAGES_PATH / machine_id
    available_images = sorted(machine_image_dir.glob("*.png"))

    # Tier 2: try to narrow the image set using stored descriptions.
    selected: List[Path] = []
    descriptions = json_data.get("image_descriptions", [])
    if descriptions:
        chosen_names = _get_targeted_images(query, descriptions)
        if chosen_names:
            selected = [p for p in available_images if p.name in chosen_names]

    # Tier 3 fallback: no targeted hits -> analyze every available image.
    if not selected:
        print(" -> [Tier 3] No targeted images selected or found. Analyzing all images as a fallback.")
        selected = available_images

    if not selected:
        # Machine has no processed images at all.
        return f"--- CONTEXT FOR {machine_id.upper()} (JSON only, no images found) ---\n\n{json_content_str}"

    analysis = _analyze_images_batch(selected, query)
    header = f"--- CONTEXT FOR {machine_id.upper()} (Full Retrieval) ---\n\n"
    json_section = f"## JSON Data:\n{json_content_str}\n\n"
    image_section = "## Batch Image Analysis:\n" + analysis
    return header + json_section + image_section
async def get_machine_context_async(query: str) -> str:
    """
    Retrieves all relevant context for the user's query (from JSON, images, etc.).
    Returns ONLY the context, not a synthesized answer.

    Async variant of get_machine_context: per-machine retrieval runs
    concurrently via _get_context_for_machine_async. Returns either the
    joined context blocks or a human-readable error/clarification string.
    """
    if not OPENROUTER_API_KEY:
        return "Error: OPENROUTER_API_KEY is not set."
    print("--- Step 1: Building machine map on-the-fly... ---")
    machine_map = {}               # machine_id -> product_name (for the router prompt)
    machine_json_map = {}          # machine_id -> parsed JSON dict
    machine_json_content_map = {}  # machine_id -> pretty-printed JSON string
    try:
        for json_file in JSON_INFO_PATH.glob("*.json"):
            machine_id = json_file.stem
            with open(json_file, 'r') as f:
                data = json.load(f)
            # Files without a product_name are skipped entirely.
            if "product_name" in data:
                machine_map[machine_id] = data["product_name"]
                machine_json_map[machine_id] = data
                machine_json_content_map[machine_id] = json.dumps(data, indent=2)
    except FileNotFoundError:
        return "Error: The product_info directory was not found. Please check the KB_BASE_PATH."
    except Exception as e:
        return f"Error building the machine map: {e}"
    if not machine_map:
        return "Error: No machine JSON files found or they are empty."
    relevant_machine_ids = _get_relevant_machines_from_llm(query, machine_map)
    if not relevant_machine_ids:
        return "I couldn't identify a specific machine from your query that matches my knowledge base. Could you please rephrase?"
    # Batch check JSON sufficiency for all relevant machines
    # (removed an unused relevant_json_content_map that was built but never read)
    relevant_json_map = {mid: machine_json_map[mid] for mid in relevant_machine_ids}
    batch_json_sufficiency = _batch_is_json_sufficient(query, relevant_json_map)
    print(f"\n-> Retrieving context for {len(relevant_machine_ids)} machine(s)...")
    # Fan out: one concurrent retrieval task per relevant machine.
    context_blocks = await asyncio.gather(
        *[
            _get_context_for_machine_async(
                query,
                machine_id,
                machine_json_map[machine_id],
                machine_json_content_map[machine_id],
                # Missing entries default to "NO" so images are still consulted.
                batch_json_sufficiency.get(machine_id, "NO") == "YES"
            )
            for machine_id in relevant_machine_ids
        ]
    )
    return "\n\n".join(context_blocks)
async def _get_context_for_machine_async(query: str, machine_id: str, json_data: dict, json_content_str: str, json_sufficient: bool) -> str:
    """Run the blocking _get_context_for_machine in the default thread pool.

    Uses asyncio.get_running_loop() — calling get_event_loop() from inside a
    coroutine is deprecated since Python 3.10 and can return the wrong loop;
    get_running_loop() is the correct, always-safe choice inside async code.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None, _get_context_for_machine, query, machine_id, json_data, json_content_str, json_sufficient
    )
def get_machine_context(query: str) -> str:
    """
    Synchronous variant of get_machine_context_async: retrieves all relevant
    context for the user's query (JSON data plus image analysis) and returns
    ONLY the context, not a synthesized answer. Machines are processed
    sequentially. Errors are reported as human-readable strings.
    """
    if not OPENROUTER_API_KEY:
        return "Error: OPENROUTER_API_KEY is not set."
    print("--- Step 1: Building machine map on-the-fly... ---")
    machine_map = {}               # machine_id -> product_name (for the router prompt)
    machine_json_map = {}          # machine_id -> parsed JSON dict
    machine_json_content_map = {}  # machine_id -> pretty-printed JSON string
    try:
        for json_file in JSON_INFO_PATH.glob("*.json"):
            machine_id = json_file.stem
            with open(json_file, 'r') as f:
                data = json.load(f)
            # Files without a product_name are skipped entirely.
            if "product_name" in data:
                machine_map[machine_id] = data["product_name"]
                machine_json_map[machine_id] = data
                machine_json_content_map[machine_id] = json.dumps(data, indent=2)
    except FileNotFoundError:
        return "Error: The product_info directory was not found. Please check the KB_BASE_PATH."
    except Exception as e:
        return f"Error building the machine map: {e}"
    if not machine_map:
        return "Error: No machine JSON files found or they are empty."
    relevant_machine_ids = _get_relevant_machines_from_llm(query, machine_map)
    if not relevant_machine_ids:
        return "I couldn't identify a specific machine from your query that matches my knowledge base. Could you please rephrase?"
    # Batch check JSON sufficiency for all relevant machines
    # (removed an unused relevant_json_content_map that was built but never read)
    relevant_json_map = {mid: machine_json_map[mid] for mid in relevant_machine_ids}
    batch_json_sufficiency = _batch_is_json_sufficient(query, relevant_json_map)
    print(f"\n-> Retrieving context for {len(relevant_machine_ids)} machine(s)...")
    context_blocks = [
        _get_context_for_machine(
            query,
            machine_id,
            machine_json_map[machine_id],
            machine_json_content_map[machine_id],
            # Missing entries default to "NO" so images are still consulted.
            batch_json_sufficiency.get(machine_id, "NO") == "YES"
        )
        for machine_id in relevant_machine_ids
    ]
    return "\n\n".join(context_blocks)
if __name__ == '__main__':
    # Smoke test: run the async retrieval pipeline against a sample query.
    # (Removed a redundant local `import asyncio`; it is already imported at
    # module level.)
    print("--- Running a direct test of the intelligent_retrieval_tool.py script ---")
    sample_query = "What is the price of the Cantek JDT75?"
    print(f"Test Query: \"{sample_query}\"")
    context = asyncio.run(get_machine_context_async(sample_query))
    print("\n--- FINAL CONTEXT ---")
    print(context)