Spaces:
Build error
Build error
| import re | |
| from multiprocessing import cpu_count | |
| from keras.src.saving import load_model | |
| import pandas as pd | |
| from keras.src.utils import set_random_seed | |
| from numpy import int64 | |
| from pandarallel import pandarallel | |
| from sklearn.preprocessing import RobustScaler | |
| import gradio as gr | |
# Module-level setup: runs once at import time, before any request is served.
_RANDOM_SEED = 65536
_MODEL_PATH = './sqid.keras'

# Seed first so everything below starts from the same random state.
set_random_seed(_RANDOM_SEED)

# One pandarallel worker per CPU reported by multiprocessing; use_memory_fs=True
# requests pandarallel's in-memory filesystem transfer mode.
pandarallel.initialize(nb_workers=cpu_count(), use_memory_fs=True)

# The classifier is loaded once and shared by every call to is_malicious_sql().
model = load_model(_MODEL_PATH)
# Ordered (old, new) substitutions applied to the raw query before splitting.
# Order matters: each entry sees the output of the ones before it.
# NOTE(review): because '=' is padded first, the later '!=' entry can never
# match; it is kept so the table mirrors the historical behaviour exactly.
_SQL_REPLACEMENTS = (
    ('`', ' '),
    ('%20', ' '),
    ('=', ' = '),
    ('((', ' (( '),
    ('))', ' )) '),
    ('(', ' ( '),
    (')', ' ) '),
    ('||', ' || '),
    (',', ''),
    ('--', ' -- '),
    (':', ' : '),
    ('%23', ' # '),
    ('+', ' + '),
    ('!=', ' != '),
    ('"', ' " '),
    ('%26', ' and '),
    ('$', ' $ '),
    ('%28', ' ( '),
    ('%2A', ' * '),
    ('%7C', ' | '),
    ('&', ' & '),
    (']', ' ] '),
    ('[', ' [ '),
    (';', ' ; '),
    ('/*', ' /* '),
)

# Tokens left untouched by the placeholder pass. Lookup is done on the
# UPPER-CASED token, so the lower-case entries below never match.
# NOTE(review): entries containing '%', '#' or '@' cannot match either, because
# those characters are stripped from every token before the lookup.
_SQL_RESERVED = frozenset({
    'SELECT', 'FROM', 'WHERE', 'AND', 'OR', 'NOT', 'IN', 'LIKE', 'ORDER', 'BY', 'GROUP', 'HAVING',
    'LIMIT', 'BETWEEN', 'IS', 'NULL', '%', 'LIKE', 'MIN', 'MAX', 'AS', 'UPPER', 'LOWER', 'TO_DATE',
    '=', '>', '<', '>=', '<=', '!=', '<>', 'BETWEEN', 'LIKE', 'EXISTS', 'JOIN', 'UNION', 'ALL',
    'ASC', 'DESC', '||', 'AVG', 'LIMIT', 'EXCEPT', 'INTERSECT', 'CASE', 'WHEN', 'THEN', 'IF',
    'IF', 'ANY', 'CAST', 'CONVERT', 'COALESCE', 'NULLIF', 'INNER', 'OUTER', 'LEFT', 'RIGHT', 'FULL',
    'CROSS', 'OVER', 'PARTITION', 'SUM', 'COUNT', 'WITH', 'INTERVAL', 'WINDOW', 'OVER',
    'ROW_NUMBER', 'RANK',
    'DENSE_RANK', 'NTILE', 'FIRST_VALUE', 'LAST_VALUE', 'LAG', 'LEAD', 'DISTINCT', 'COMMENT',
    'INSERT',
    'UPDATE', 'DELETED', 'MERGE', '*', 'generate_series', 'char', 'chr', 'substr', 'lpad',
    'extract',
    'year', 'month', 'day', 'timestamp', 'number', 'string', 'concat', 'INFORMATION_SCHEMA',
    "SQLITE_MASTER", 'TABLES', 'COLUMNS', 'CUBE', 'ROLLUP', 'RECURSIVE', 'FILTER', 'EXCLUDE',
    'AUTOINCREMENT', 'WITHOUT', 'ROWID', 'VIRTUAL', 'INDEXED', 'UNINDEXED', 'SERIAL',
    'DO', 'RETURNING', 'ILIKE', 'ARRAY', 'ANYARRAY', 'JSONB', 'TSQUERY', 'SEQUENCE',
    'SYNONYM', 'CONNECT', 'START', 'LEVEL', 'ROWNUM', 'NOCOPY', 'MINUS', 'AUTO_INCREMENT', 'BINARY',
    'ENUM', 'REPLACE', 'SET', 'SHOW', 'DESCRIBE', 'USE', 'EXPLAIN', 'STORED', 'VIRTUAL', 'RLIKE',
    'MD5', 'SLEEP', 'BENCHMARK', '@@VERSION', 'VERSION', '@VERSION', 'CONVERT', 'NVARCHAR', '#',
    '##', 'INJECTX',
    'DELAY', 'WAITFOR', 'RAND',
})

# Characters outside this class are deleted from every token.
_DISALLOWED_CHARS_RE = re.compile(r"""[^*\w\s.=\-><_|()!"']""")
_IDENTIFIER_RE = re.compile(r'^[a-zA-Z_.|][a-zA-Z0-9_.|]*$')
# Also matches the empty string, so a token emptied by the character filter
# (e.g. a lone ';' or '+') is tagged as a timestamp.
_TIMESTAMP_RE = re.compile(r'^[\d:]*$')


def _classify_token(token):
    """Return *token* itself or the placeholder tag that replaces it."""
    if token.upper() in _SQL_RESERVED:
        return token
    if token.isnumeric():
        return '#NUMBER#'
    if _IDENTIFIER_RE.match(token):
        return '#IDENTIFIER#'
    if _TIMESTAMP_RE.match(token):
        return '#TIMESTAMP#'
    if '%' in token:
        # NOTE(review): currently unreachable -- '%' is removed by
        # _DISALLOWED_CHARS_RE before classification. Kept rather than
        # "fixed" because enabling it would change the token stream.
        return ' '.join(
            part.strip() if part.strip() in ('%', "'") else '#IDENTIFIER#'
            for part in token.split('%'))
    return token


def sql_tokenize(sql_query):
    """Normalise a SQL string into a space-separated stream of tokens and tags.

    Steps:
      1. Apply ``_SQL_REPLACEMENTS`` in order (pads operators/brackets with
         spaces, decodes a few URL escapes, drops commas).
      2. Split on whitespace and delete every character not allowed by
         ``_DISALLOWED_CHARS_RE``.
      3. Keep reserved tokens; replace the rest with ``#NUMBER#``,
         ``#IDENTIFIER#`` or ``#TIMESTAMP#`` where a rule matches, otherwise
         leave the token unchanged.

    The output is byte-for-byte what the previous implementation produced,
    including its quirks (flagged in the NOTE comments above). The result is
    fed to the loaded model, which presumably was trained on this exact
    format -- confirm against the training pipeline before changing any rule.

    Args:
        sql_query: the raw query text (``str``).

    Returns:
        The normalised query as a single string; ``''`` for empty or
        whitespace-only input.
    """
    for old, new in _SQL_REPLACEMENTS:
        sql_query = sql_query.replace(old, new)
    # Tokens produced by split() contain no whitespace, so no strip() is needed.
    cleaned = (_DISALLOWED_CHARS_RE.sub('', raw) for raw in sql_query.split())
    return ' '.join(_classify_token(token) for token in cleaned)
def add_features(x):
    """Tokenise ``x['Query']`` and append the hand-crafted feature columns.

    The frame is modified in place and also returned. ``Query`` is overwritten
    with the output of ``sql_tokenize``; every feature below is computed from
    that tokenised text, not from the raw query.

    Columns added: ``num_tables``, ``num_columns``, ``num_literals``,
    ``num_parentheses``, ``has_union``, ``depth_nested_queries``, ``num_join``,
    ``num_sp_chars``, ``has_mismatched_quotes``, ``has_tautology``.
    """

    def occurrences(pattern):
        # Regex hit count per row on the lower-cased tokenised text,
        # matched case-insensitively.
        return lowered.str.count(pattern, flags=re.I)

    def count_special_chars(query):
        return len(re.findall(r'[\'";\-*/%=><|#]', query))

    def flag_mismatched_quotes(query):
        return 1 if re.search(r"'.*[^']$|\".*[^\"]$", query) else 0

    def flag_tautology(query):
        return 1 if re.search(r"'[\s]*=[\s]*'", query) else 0

    x['Query'] = x['Query'].copy().parallel_apply(sql_tokenize)
    tokenised = x['Query']
    lowered = tokenised.str.lower()

    x['num_tables'] = occurrences(r'FROM\s+#IDENTIFIER#')
    x['num_columns'] = occurrences(r'SELECT\s+#IDENTIFIER#')
    # NOTE(review): the double-quote pattern has no '*', so it only matches a
    # single character between double quotes -- confirm this is intended.
    x['num_literals'] = occurrences("'[^']*'") + occurrences('"[^"]"')
    x['num_parentheses'] = occurrences("\\(") + occurrences('\\)')
    x['has_union'] = (occurrences(" union |union all") > 0).astype(int64)
    # Despite the name, this is the number of '(' characters, not a depth.
    x['depth_nested_queries'] = occurrences("\\(")
    x['num_join'] = occurrences(
        " join |inner join|outer join|full outer join|full inner join|cross join|left join|right join")

    # These three are evaluated on the tokenised text without lower-casing.
    x['num_sp_chars'] = tokenised.parallel_apply(count_special_chars)
    x['has_mismatched_quotes'] = tokenised.parallel_apply(flag_mismatched_quotes)
    x['has_tautology'] = tokenised.parallel_apply(flag_tautology)
    return x
def is_malicious_sql(sql, threshold, scaler=None):
    """Score one SQL string with the loaded model and label it.

    Args:
        sql: the raw query text.
        threshold: score above which the query is reported as malicious;
            anything ``float()`` accepts.
        scaler: optional, already-fitted scaler exposing ``transform``. When
            given, it is used to scale the numeric features. When omitted,
            the historical behaviour is kept (see the note below).

    Returns:
        ``'Malicious - <score>'`` if the score is strictly greater than the
        threshold, otherwise ``'Safe - <score>'``.

    Raises:
        ValueError / TypeError: if ``threshold`` cannot be converted to float.
    """
    # Validate the threshold before doing any expensive work.
    threshold = float(threshold)

    input_df = add_features(pd.DataFrame([sql], columns=['Query']))
    numeric_features = ["num_tables", "num_columns", "num_literals", "num_parentheses", "has_union",
                        "depth_nested_queries", "num_join", "num_sp_chars", "has_mismatched_quotes", "has_tautology"]

    if scaler is None:
        # NOTE(review): legacy path. A RobustScaler fitted on this single row
        # centres every feature on its own value, so the model receives all
        # zeros for the numeric branch no matter what the query contains.
        # Pass the scaler fitted at training time via ``scaler`` to avoid this.
        x_in = RobustScaler().fit_transform(input_df[numeric_features])
    else:
        x_in = scaler.transform(input_df[numeric_features])

    preds = model.predict([input_df['Query'], x_in]).tolist()[0][0]
    verdict = 'Malicious' if preds > threshold else 'Safe'
    return f'{verdict} - {preds}'
def _relabel_cached_reply(reply, threshold):
    """Re-evaluate an earlier ``'<label> - <score>'`` reply against *threshold*.

    Returns the refreshed reply, or ``None`` when *reply* does not end in a
    parseable score (so the caller can fall back to recomputing).
    """
    if not isinstance(reply, str):
        return None
    _, separator, score_text = reply.rpartition(' - ')
    if not separator:
        return None
    try:
        score = float(score_text)
    except ValueError:
        return None
    label = 'Malicious' if score > float(threshold) else 'Safe'
    # Re-use the original score text so the number shown does not change.
    return f'{label} - {score_text}'


def respond(
        message,
        history,
        threshold
):
    """Chat callback: classify *message*, reusing a score from *history* if any.

    Args:
        message: the query text entered by the user.
        history: earlier turns, each indexable as ``[user_text, reply_text]``.
        threshold: current value of the threshold slider.

    Returns:
        The verdict string for *message*.

    A previous reply to the same (case- and whitespace-insensitive) message
    is reused for its score only; the label is recomputed against the current
    threshold. Returning the old reply verbatim would show a stale label
    after the slider has been moved.
    """
    normalized = message.lower().strip()

    # Same trimming as before: with more than five turns, skip the oldest one.
    if len(history) > 5:
        history = history[1:]

    for turn in history:
        if turn[0].lower().strip() != normalized:
            continue
        refreshed = _relabel_cached_reply(turn[1], threshold)
        if refreshed is not None:
            return refreshed

    val = (normalized, is_malicious_sql(message, threshold))
    print(val)
    return val[1]
| """ | |
| For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface | |
| """ | |
# Extra input shown with the chat box; its value is passed to respond() as
# the ``threshold`` argument.
_threshold_slider = gr.Slider(
    label="Detection Probability Threshold ",
    value=0.75,
    minimum=0.01,
    maximum=0.99,
    step=0.01,
)

_DESCRIPTION = 'Please enter a SQL query as your input. You may adjust the minimum probability threshold for reporting SQLs as malicious using the slider below.'

demo = gr.ChatInterface(
    respond,
    title='SafeSQL-v1-Demo',
    description=_DESCRIPTION,
    additional_inputs=[_threshold_slider],
)

# Start the web UI only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()