Spaces:
Paused
Paused
refactor fetch_misclassified_dataframe to improve SQL queries and remove unused imports
Browse files- prepare_pd_df.py +20 -21
prepare_pd_df.py
CHANGED
|
@@ -1,9 +1,7 @@
|
|
| 1 |
# prepare_pd_dataframe.py
|
| 2 |
|
| 3 |
-
import os
|
| 4 |
-
import random
|
| 5 |
import pandas as pd
|
| 6 |
-
from sqlalchemy import
|
| 7 |
|
| 8 |
|
| 9 |
def fetch_misclassified_dataframe(label_column: str,
|
|
@@ -34,7 +32,6 @@ def fetch_misclassified_dataframe(label_column: str,
|
|
| 34 |
|
| 35 |
# define conditions based on column
|
| 36 |
miscond = f"mc.correct_{label_column} IS NOT NULL AND mc.model_predicted_{label_column} IS DISTINCT FROM mc.correct_{label_column}"
|
| 37 |
-
corrcond = f"mc.correct_{label_column} IS NOT NULL AND mc.model_predicted_{label_column} = mc.correct_{label_column}"
|
| 38 |
|
| 39 |
# SQL to fetch misclassified records
|
| 40 |
sql_mis = text(f"""
|
|
@@ -56,16 +53,17 @@ def fetch_misclassified_dataframe(label_column: str,
|
|
| 56 |
n_mis = len(df_mis)
|
| 57 |
n_correct = int(n_mis * correct_ratio)
|
| 58 |
|
| 59 |
-
# SQL to fetch correct records
|
| 60 |
sql_corr = text(f"""
|
| 61 |
-
SELECT c.
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
WHERE
|
| 67 |
-
AND {
|
| 68 |
""")
|
|
|
|
| 69 |
with engine.connect() as conn:
|
| 70 |
df_corr_all = pd.read_sql(sql_corr, conn)
|
| 71 |
|
|
@@ -83,12 +81,13 @@ def fetch_misclassified_dataframe(label_column: str,
|
|
| 83 |
|
| 84 |
return df_combined
|
| 85 |
|
| 86 |
-
# If this file is run directly, simple test:
|
| 87 |
-
if __name__ == "__main__":
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
| 1 |
# prepare_pd_dataframe.py
|
| 2 |
|
|
|
|
|
|
|
| 3 |
import pandas as pd
|
| 4 |
+
from sqlalchemy import text
|
| 5 |
|
| 6 |
|
| 7 |
def fetch_misclassified_dataframe(label_column: str,
|
|
|
|
| 32 |
|
| 33 |
# define conditions based on column
|
| 34 |
miscond = f"mc.correct_{label_column} IS NOT NULL AND mc.model_predicted_{label_column} IS DISTINCT FROM mc.correct_{label_column}"
|
|
|
|
| 35 |
|
| 36 |
# SQL to fetch misclassified records
|
| 37 |
sql_mis = text(f"""
|
|
|
|
| 53 |
n_mis = len(df_mis)
|
| 54 |
n_correct = int(n_mis * correct_ratio)
|
| 55 |
|
| 56 |
+
# SQL to fetch correct records from complaints table NOT in misclassified_complaints
|
| 57 |
sql_corr = text(f"""
|
| 58 |
+
SELECT c.id AS complaint_id,
|
| 59 |
+
c.message AS grievance,
|
| 60 |
+
c.department AS department,
|
| 61 |
+
c.urgency AS urgency
|
| 62 |
+
FROM complaints c
|
| 63 |
+
WHERE c.id NOT IN (SELECT complaint_id FROM misclassified_complaints)
|
| 64 |
+
AND c.{label_column} IS NOT NULL
|
| 65 |
""")
|
| 66 |
+
|
| 67 |
with engine.connect() as conn:
|
| 68 |
df_corr_all = pd.read_sql(sql_corr, conn)
|
| 69 |
|
|
|
|
| 81 |
|
| 82 |
return df_combined
|
| 83 |
|
| 84 |
+
# # If this file is run directly, simple test:
|
| 85 |
+
# if __name__ == "__main__":
|
| 86 |
+
# # Quick sanity test for department label
|
| 87 |
+
# df_test = fetch_misclassified_dataframe(label_column="department",
|
| 88 |
+
# correct_ratio=0.5)
|
| 89 |
+
# print("Rows fetched:", len(df_test))
|
| 90 |
+
# print(df_test.head())
|
| 91 |
+
# # Basic assertion: if rows>0 then none of grievances should be null
|
| 92 |
+
# if len(df_test) > 0:
|
| 93 |
+
# assert df_test['grievance'].isna().sum() == 0, "Some grievances are null"
|