Spaces:
Sleeping
Sleeping
| from typing import Any, Dict, Tuple, List | |
| from loguru import logger | |
| import json | |
| import re | |
def _normalize_number(match):
    """Return the canonical spelling of a numeric literal matched by a regex.

    Args:
        match: A regex match whose full text (``group(0)``) is a number.

    Returns:
        ``str(float(text))`` when the text contains a decimal point, so that
        spellings such as ``1.50`` and ``1.5`` become identical; otherwise the
        text unchanged.
    """
    num_str = match.group(0)
    if '.' in num_str:
        # Normalize float by removing trailing zeros and decimal point if needed
        return str(float(num_str))
    return num_str  # Leave integers as is


def clean_query(query: str) -> str:
    """
    Normalize a MongoDB query string so that equivalent queries compare equal.

    Steps, in order:
    - strip surrounding whitespace
    - convert '''<query>''' to <query>
    - replace ' with "
    - remove all spaces (including spaces inside string values)
    - remove \\n
    - remove empty brackets {}
    - remove .toArray()
    - normalize unquoted decimal literals (1.50 -> 1.5)

    Args:
        query: Raw query text.

    Returns:
        The normalized query text.
    """
    query = query.strip()
    # Unwrap the triple-quote wrapper BEFORE single quotes are rewritten to
    # double quotes; once they are rewritten the wrapper can no longer be
    # recognized and would survive as a literal """ prefix/suffix.
    if len(query) >= 6 and query.startswith("'''") and query.endswith("'''"):
        query = query[3:-3].strip()
    # replace ' with "
    query = query.replace("'", "\"")
    # Remove all spaces
    query = query.replace(" ", "")
    # Remove \n
    query = query.replace("\n", "")
    # Remove empty brackets {}
    query = query.replace("{}", "")
    # Replace .toArray() with ""
    query = query.replace(".toArray()", "")
    # Normalize number strings; the lookarounds skip numbers that touch a
    # double quote or a word character (quoted values, identifiers).
    query = re.sub(r'(?<!["\w])(-?\d+\.\d+)(?!["\w])', _normalize_number, query)
    return query
def extract_field_paths(properties: Dict[str, Any], prefix: str = "") -> Dict[str, str]:
    """
    Map every leaf property name in a Mongo JSON Schema 'properties' dict
    to its full dot-path.

    Objects with their own 'properties', and arrays whose 'items' is such an
    object, are descended into; everything else is a leaf.
    If the same leaf name occurs under several parents, the last one visited wins.

    Returns {field_name: full_path}
    """
    def child_properties(spec):
        # Returns the nested 'properties' dict to descend into, or None for a leaf.
        kind = spec.get("bsonType")
        if kind == "object" and "properties" in spec:
            return spec["properties"]
        if kind == "array" and "items" in spec:
            item_spec = spec["items"]
            if item_spec.get("bsonType") == "object" and "properties" in item_spec:
                return item_spec["properties"]
        return None

    leaf_paths: Dict[str, str] = {}
    for name, spec in properties.items():
        full_path = prefix + name
        children = child_properties(spec)
        if children is None:
            leaf_paths[name] = full_path
        else:
            leaf_paths.update(extract_field_paths(children, full_path + "."))
    return leaf_paths
def build_schema_maps(schema: Dict[str, Any]) -> Tuple[Dict[str, str], Dict[str, str]]:
    """
    Build the two lookup tables for a full JSON Schema.

    Only the first entry of schema["collections"] is read.

    Returns a pair:
    - input_to_output: field_name -> nested field path
    - output_to_input: nested field path -> field_name
    """
    document = schema["collections"][0]["document"]
    forward = extract_field_paths(document["properties"])
    # Invert the forward map to look names up by path.
    backward: Dict[str, str] = {}
    for name, path in forward.items():
        backward[path] = name
    return forward, backward
def set_nested(d: Dict[str, Any], keys: List[str], value: Any) -> None:
    """
    Store value in d under the chain of keys, creating intermediate dicts
    for keys that are missing. Mutates d in place.
    """
    if len(keys) > 1:
        # Descend one level (creating it if absent) and handle the rest there.
        first, remaining = keys[0], keys[1:]
        set_nested(d.setdefault(first, {}), remaining, value)
        return
    d[keys[-1]] = value
def dot_notation_to_nested(dot: Dict[str, Any]) -> Dict[str, Any]:
    """
    Expand dot-notation keys into a nested dict structure.
    E.g. {"a.b": v} -> {"a": {"b": v}}
    """
    nested: Dict[str, Any] = {}
    for dotted_key in dot:
        set_nested(nested, dotted_key.split('.'), dot[dotted_key])
    return nested
def nested_to_dot(d: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """
    Flatten a nested dict into dot-notation keys.

    A non-empty dict whose keys all start with "$" (an operator dict) is kept
    whole as a leaf value. Empty dicts contribute no entries.
    """
    flattened: Dict[str, Any] = {}
    for name, value in d.items():
        path = name if not prefix else f"{prefix}.{name}"
        is_operator_leaf = (
            isinstance(value, dict)
            and bool(value)
            and all(str(op).startswith("$") for op in value)
        )
        if isinstance(value, dict) and not is_operator_leaf:
            # Plain sub-document: keep descending.
            flattened.update(nested_to_dot(value, path))
        else:
            flattened[path] = value
    return flattened
def modified_to_actual_query(modified: Dict[str, Any],
                             input_to_output: Dict[str, str]) -> Dict[str, Any]:
    """
    Turn a flat filter dict (field_name -> value/operator) into a nested
    Mongo query dict.

    Field names found in input_to_output are placed at their mapped path;
    any other key is itself interpreted as a dot-notation path.
    """
    nested_query: Dict[str, Any] = {}
    for field_name, condition in modified.items():
        # Unknown names fall back to being used as their own dotted path.
        dotted_path = input_to_output.get(field_name, field_name)
        set_nested(nested_query, dotted_path.split('.'), condition)
    return nested_query
def actual_to_modified_query(actual: Dict[str, Any],
                             output_to_input: Dict[str, str]) -> Dict[str, Any]:
    """
    Flatten a nested Mongo query dict back into field_name -> value/operator.

    Operator-dicts (non-empty dicts whose keys all start with "$") that sit
    under a field path are treated as leaves. Top-level operators such as
    {"$or": [...]} are kept under their own operator key instead of being
    collapsed under an empty-string key.
    If a path is not in the output_to_input mapping, it is preserved as-is.

    Args:
        actual: Nested (or dot-notation) Mongo filter.
        output_to_input: Map of nested dot-path -> flat field name.

    Returns:
        Flat dict keyed by mapped field name, or by dot-path when unmapped.
    """
    flat: Dict[str, Any] = {}

    def is_operator_dict(node: Any) -> bool:
        # str() keeps this consistent with nested_to_dot for non-string keys.
        return (
            isinstance(node, dict)
            and bool(node)
            and all(str(key).startswith("$") for key in node)
        )

    def recurse(node: Any, prefix: str = "") -> None:
        # Leaf: any non-dict value, or an operator dict attached to a field.
        # The prefix check stops a purely-operator top-level filter from being
        # stored under the key "".
        if not isinstance(node, dict) or (prefix and is_operator_dict(node)):
            flat[output_to_input.get(prefix, prefix)] = node
            return
        # recurse deeper
        for key, value in node.items():
            recurse(value, f"{prefix}.{key}" if prefix else key)

    recurse(actual)
    return flat
def build_query_and_options(
    modified: Dict[str, Any],
    input_to_output: Dict[str, str]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Split a flat input dict into a Mongo filter and query options.

    The special keys limit, skip, sort and projection are popped out of
    `modified` (the caller's dict is mutated; pass a copy to keep it intact).
    Whatever remains is treated as filter fields.

    Returns:
    - nested Mongo filter dict
    - options dict holding those of limit, skip, sort, projection that were present
    """
    # Pull out the special keys, in this order.
    options: Dict[str, Any] = {
        name: modified.pop(name)
        for name in ("limit", "skip", "sort", "projection")
        if name in modified
    }
    # Everything left over is a filter field.
    return modified_to_actual_query(modified, input_to_output), options
def convert_modified_to_actual_code_string(
    modified_input: dict,
    in2out: dict,
    collection_name: str = "events"
) -> str:
    """
    Converts a modified (flat) dict into a MongoDB code string.

    The filter is printed in dot-notation to match db.find syntax, the
    projection argument is omitted when empty, and .sort/.skip/.limit are
    chained only when they have a truthy value.

    Args:
        modified_input: Flat dict of filter fields plus the optional keys
            limit, skip, sort, projection. Keys starting with "_" are treated
            as internal metadata and ignored, except "_id", which is a real
            document field. "_original_query_format", when present and
            containing "newDate", selects newDate("...") instead of
            ISODate("...") for date strings. The dict is not mutated.
        in2out: Map of flat field name -> nested dot-path.
        collection_name: Collection used in the generated db.<name>.find(...).

    Returns:
        The query as a single string, e.g. db.events.find({...}).limit(5)
    """
    # Matches ISO date strings as serialized by json.dumps,
    # like "2024-01-01T00:00:00Z"
    date_pattern = r'"(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z)"'

    # Read the format hint BEFORE metadata keys are stripped; reading it
    # afterwards would always see it missing and fall back to ISODate.
    original_format = str(modified_input.get("_original_query_format") or "")
    date_wrapper = "newDate" if "newDate" in original_format else "ISODate"
    date_replacement = date_wrapper + r'("\1")'

    def restore_dates(text: str) -> str:
        # Re-wrap quoted ISO date strings in the chosen date constructor.
        return re.sub(date_pattern, date_replacement, text)

    # Remove internal metadata fields before processing, but keep "_id".
    user_fields = {
        k: v for k, v in modified_input.items()
        if k == "_id" or not k.startswith('_')
    }
    filter_dict, opts = build_query_and_options(user_fields, in2out)

    # 1) dot-ify the filter dict
    dot_filter = nested_to_dot(filter_dict)
    filter_str = json.dumps(dot_filter, separators=(",", ":"))
    # 2) Convert date strings back to the appropriate MongoDB date format
    filter_str = restore_dates(filter_str)
    # 3) Restore special time expressions that might have been converted
    time_expr_pattern = r'"(newDate\.getTime\(\)-\d+)"'
    filter_str = re.sub(time_expr_pattern, r'\1', filter_str)

    # 4) only include projection if non-empty
    find_args = filter_str
    projection = opts.get("projection")
    if projection:
        projection_str = json.dumps(projection, separators=(',', ':'))
        find_args += ", " + restore_dates(projection_str)
    parts = [f"db.{collection_name}.find({find_args})"]

    # 5) chain optional methods
    sort_value = opts.get("sort")
    if sort_value:
        if isinstance(sort_value, list):
            # Convert [(key, direction), ...] to {key: direction}
            sort_value = dict(sort_value)
        if isinstance(sort_value, dict):
            # Hand-built so directions stay bare (e.g. -1) and keys stay quoted.
            sort_items = [f'"{k}":{v}' for k, v in sort_value.items()]
            sort_str = '{' + ','.join(sort_items) + '}'
        else:
            sort_str = str(sort_value)
        parts.append(f".sort({sort_str})")
    if opts.get("skip"):
        parts.append(f".skip({opts['skip']})")
    if opts.get("limit"):
        parts.append(f".limit({opts['limit']})")
    return "".join(parts)
def convert_actual_code_to_modified_dict(actual_code: str, out2in: dict) -> dict:
    """
    Converts an actual MongoDB query string into a modified flat dictionary.
    WARNING: This assumes the input is sanitized and safe (e.g., evaluated from a trusted source).

    The string is first preprocessed (ISODate(...)/newDate(...) wrappers are
    replaced by plain quoted strings). It is then parsed with regex-based
    extraction of the find() filter, the projection and the chained
    sort/skip/limit calls. If that raises, the preprocessed string is parsed
    with ``ast`` as a fallback.

    Args:
        actual_code: Query text of the form db.<collection>.find(...),
            optionally followed by .sort(...), .skip(...), .limit(...).
        out2in: Map of nested dot-path -> flat field name, handed to
            actual_to_modified_query.

    Returns:
        The flattened filter plus projection/sort/skip/limit entries when
        present. The regex path also stores captured decimal literals under
        "_original_numbers"; the AST fallback does not add that key.

    Raises:
        ValueError: If both the regex-based and the AST-based parsing fail.
    """
    import ast
    import re
    import json
    from datetime import datetime, timedelta
    # Store original number strings: str(float(text)) -> longest text seen
    original_numbers = {}

    def store_number_strings(s: str) -> str:
        # Records decimal literals found in s into original_numbers.
        # s itself is returned unchanged.
        def replace_number(match):
            num_str = match.group(0)
            # Only store if it has a decimal point (to preserve trailing zeros)
            if '.' in num_str:
                try:
                    num = float(num_str)
                    # Store the longest representation for this float
                    key = str(num)
                    if key not in original_numbers or len(num_str) > len(original_numbers[key]):
                        original_numbers[key] = num_str
                except ValueError:
                    pass
            return num_str
        # Match optionally negative numbers that contain a decimal point
        number_pattern = r'-?\d+\.\d+'
        # re.sub is used only for the callback's side effect; its result is discarded
        re.sub(number_pattern, replace_number, s)
        return s

    def preprocess_mongo_syntax(query_str):
        # Rewrites Mongo shell date constructors into plain quoted strings.
        # NOTE(review): the patterns match "newDate" with no space, so this
        # presumably expects input whose spaces were already removed - confirm.
        store_number_strings(query_str)
        # Replace ISODate("..."), ISODate('...') with the date string
        query_str = re.sub(r'ISODate\("([^"]+)"\)', r'"\1"', query_str)
        query_str = re.sub(r"ISODate\('([^']+)'\)", r'"\1"', query_str)
        # Handle newDate(newDate().getTime()-<expr>)
        def newdate_minus_expr(match):
            # expr includes its leading operator, e.g. "-86400000"
            expr = match.group(1)
            try:
                # The regex below only lets digits, spaces and - + * / reach eval,
                # and builtins are removed.
                # NOTE(review): "**" chains (e.g. 9**9**9) still pass that
                # character filter and can be very expensive to evaluate.
                ms = int(eval(expr, {"__builtins__": None}, {}))
                from datetime import datetime, timedelta
                # Offset the current UTC time by the evaluated millisecond amount
                dt = datetime.utcnow() + timedelta(milliseconds=ms)
                return '"' + dt.strftime('%Y-%m-%dT%H:%M:%SZ') + '"'
            except Exception:
                return '"1970-01-01T00:00:00Z"'  # fallback
        query_str = re.sub(r'newDate\(newDate\(\)\.getTime\(\)([-+*/0-9 ]+)\)', newdate_minus_expr, query_str)
        # Replace newDate() with current UTC time
        from datetime import datetime
        now = datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ')
        query_str = re.sub(r'newDate\(\)', f'"{now}"', query_str)
        # Replace newDate(expr) with a string (handle both quote types)
        query_str = re.sub(r'newDate\("([^"]+)"\)', r'"\1"', query_str)
        query_str = re.sub(r"newDate\('([^']+)'\)", r'"\1"', query_str)
        # Catch-all: any remaining newDate(<arg>) has its argument wrapped in quotes
        query_str = re.sub(r'newDate\((.*?)\)', r'"\1"', query_str)
        # Fix unbalanced brackets by appending the missing closing braces
        if query_str.count('{') > query_str.count('}'):
            query_str += "}" * (query_str.count('{') - query_str.count('}'))
        return query_str

    # Extract filter dictionary from find() call using regex
    def extract_filter_dict(code):
        # Match db.collection.find(...) pattern; the non-greedy capture stops at
        # the first ")" or at the first ", {" that follows.
        find_pattern = r'db\.[^.]+\.find\((.*?)(?:\)|,\s*{)'
        find_match = re.search(find_pattern, code)
        if not find_match:
            raise ValueError("Could not extract filter parameters from find() call")
        filter_str = find_match.group(1)
        # If empty, return empty dict
        if not filter_str.strip():
            return {}
        try:
            # Try parsing as JSON
            return json.loads(filter_str)
        except json.JSONDecodeError:
            # Try with ast.literal_eval
            try:
                return ast.literal_eval(filter_str)
            except:
                # Last resort - try fixing common issues and retry
                fixed_str = filter_str.replace("'", '"')
                try:
                    return json.loads(fixed_str)
                except:
                    raise ValueError(f"Could not parse filter dictionary: {filter_str}")

    # Extract projection dictionary from find() call using regex
    def extract_projection_dict(code):
        # Match find(..., {projection}) pattern; returns None when there is no
        # second brace group or it cannot be parsed.
        proj_pattern = r'find\([^{]*({[^}]*})[^{]*,\s*{([^}]*)}\s*\)'
        proj_match = re.search(proj_pattern, code)
        if not proj_match:
            return None
        # group(2) is the text between the projection's braces, braces excluded
        proj_str = proj_match.group(2)
        try:
            # Try parsing as JSON
            return json.loads(proj_str.replace("'", '"'))
        except:
            # Try with ast.literal_eval
            try:
                return ast.literal_eval(proj_str)
            except:
                return None

    # Extract method parameters using regex for cases where ast.literal_eval fails
    def extract_method_params(code, method_name):
        # Look for .method_name({...}) or .method_name([...]) or .method_name(123) pattern,
        # followed by another chained call or the end of the string
        pattern = fr'\.{method_name}\s*\((.*?)\)(?:\.|\s*$)'
        match = re.search(pattern, code)
        if not match:
            return None
        param_str = match.group(1).strip()
        # Empty parameter
        if not param_str:
            return None
        # Try to handle different parameter types
        try:
            # Simple number?
            if param_str.isdigit():
                return int(param_str)
            # JSON object or array?
            try:
                # Handle MongoDB format with double quotes
                return json.loads(param_str.replace("'", '"'))
            except json.JSONDecodeError:
                # If direct JSON parsing fails, try to use ast.literal_eval
                try:
                    return ast.literal_eval(param_str)
                except:
                    # Return as is if all else fails
                    return param_str
        except Exception as e:
            # Return None if all parsing fails
            logger.warning(f"Failed to parse parameter for {method_name}: {e}")
            return None

    # Pre-process the query
    preprocessed_code = preprocess_mongo_syntax(actual_code)
    try:
        # Try to use our more robust regex-based parsing first
        filter_dict = extract_filter_dict(preprocessed_code)
        projection = extract_projection_dict(preprocessed_code)
        # Handle empty projection
        options = {"projection": projection} if projection else {}
        # Extract sort, limit and skip parameters
        sort_param = extract_method_params(preprocessed_code, "sort")
        if sort_param is not None:
            options["sort"] = sort_param
        limit_param = extract_method_params(preprocessed_code, "limit")
        if limit_param is not None:
            # int() raises for a non-numeric string, which sends control to the AST fallback
            options["limit"] = int(limit_param) if isinstance(limit_param, (int, str)) else limit_param
        skip_param = extract_method_params(preprocessed_code, "skip")
        if skip_param is not None:
            options["skip"] = int(skip_param) if isinstance(skip_param, (int, str)) else skip_param
        # Convert actual filter_dict back to modified
        flat_filter = actual_to_modified_query(filter_dict, out2in)
        # Merge projection, sort, limit into modified if relevant
        for key in ("projection", "sort", "skip", "limit"):
            if key in options and options[key] is not None:
                flat_filter[key] = options[key]
        # Add original number strings to the result
        flat_filter['_original_numbers'] = original_numbers
        return flat_filter
    except Exception as e:
        # Fall back to traditional AST-based parsing if regex fails
        try:
            node = ast.parse(preprocessed_code.strip(), mode='eval')
            # NOTE(review): for a chained query such as find(...).limit(5) the
            # outermost call is .limit, so this check appears to reject it - confirm.
            if not isinstance(node.body, ast.Call) or not hasattr(node.body.func, 'attr') or node.body.func.attr != "find":
                raise ValueError("Expected .find(...) style query")
            # extract find(filter, projection)
            args = node.body.args
            filter_dict = ast.literal_eval(args[0])
            projection = ast.literal_eval(args[1]) if len(args) > 1 else None
            # extract chained methods: sort, skip, limit
            options = {"projection": projection} if projection else {}
            current = node.body
            # Walk inward through the call chain via each call's receiver
            while isinstance(current, ast.Call):
                func = current.func
                if hasattr(func, "attr"):
                    if func.attr == "sort":
                        options["sort"] = ast.literal_eval(current.args[0])
                    elif func.attr == "skip":
                        options["skip"] = ast.literal_eval(current.args[0])
                    elif func.attr == "limit":
                        options["limit"] = ast.literal_eval(current.args[0])
                current = func.value if hasattr(func, "value") else None
            # Convert actual filter_dict back to modified
            flat_filter = actual_to_modified_query(filter_dict, out2in)
            # Merge projection, sort, limit into modified if relevant
            for key in ("projection", "sort", "skip", "limit"):
                if key in options:
                    flat_filter[key] = options[key]
            return flat_filter
        except Exception as nested_e:
            raise ValueError(f"Failed to parse MongoDB query string: {e}. AST fallback also failed: {nested_e}")
# -------------------- Example Usage --------------------
if __name__ == "__main__":
    # Field definitions of the example "events" collection
    document_properties = {
        "event_id": {"bsonType": "int"},
        "timestamp": {"bsonType": "int"},
        "severity_level": {"bsonType": "int"},
        "camera_id": {"bsonType": "int"},
        "vehicle_details": {
            "bsonType": "object",
            "properties": {
                "license_plate_number": {"bsonType": "string"},
                "vehicle_type": {"bsonType": "string"},
                "color": {"bsonType": "string"},
            },
        },
        "person_details": {
            "bsonType": "object",
            "properties": {
                "match_id": {"bsonType": "int"},
                "age": {"bsonType": "int"},
                "gender": {"bsonType": "string"},
                "clothing_description": {"bsonType": "string"},
            },
        },
        "location": {
            "bsonType": "object",
            "properties": {
                "latitude": {"bsonType": "double"},
                "longitude": {"bsonType": "double"},
            },
        },
        "sensor_readings": {
            "bsonType": "object",
            "properties": {
                "temperature": {"bsonType": "double"},
                "speed": {"bsonType": "double"},
                "distance": {"bsonType": "double"},
            },
        },
        "incident_type": {"bsonType": "string"},
    }
    # Example JSON Schema wrapping those fields
    schema = {
        "collections": [
            {
                "name": "events",
                "document": {"properties": document_properties},
            }
        ],
        "version": 1,
    }

    # Build mappings once
    in2out, out2in = build_schema_maps(schema)

    # Flat user input including filters + options
    modified_input = {
        "license_plate_number": {"$regex": "^MH12"},
        "timestamp": {"$gte": 1684080000, "$lte": 1684166400},
        "severity_level": 3,
        "limit": 50,
        "skip": 10,
        "sort": [("timestamp", -1)],
        "projection": {
            "vehicle_details.license_plate_number": 1,
            "timestamp": 1,
            "_id": 0,
        },
    }

    # Build actual nested query + options (on a copy, since the option keys are popped)
    filter_dict, opts = build_query_and_options(modified_input.copy(), in2out)
    print("filter_dict =", filter_dict)
    print("options =", opts)
    # With a pymongo-style `db` handle, the result could then be used like:
    # cursor = (
    #     db.events.find(filter_dict, opts.get("projection"))
    #     .sort(opts.get("sort", []))
    #     .skip(opts.get("skip", 0))
    #     .limit(opts.get("limit", 0))
    # )