mr-kush committed on
Commit
a546051
·
1 Parent(s): 147f8a4

refactor fetch_misclassified_dataframe to improve SQL queries and remove unused imports

Browse files
Files changed (1) hide show
  1. prepare_pd_df.py +20 -21
prepare_pd_df.py CHANGED
@@ -1,9 +1,7 @@
1
  # prepare_pd_dataframe.py
2
 
3
- import os
4
- import random
5
  import pandas as pd
6
- from sqlalchemy import create_engine, text
7
 
8
 
9
  def fetch_misclassified_dataframe(label_column: str,
@@ -34,7 +32,6 @@ def fetch_misclassified_dataframe(label_column: str,
34
 
35
  # define conditions based on column
36
  miscond = f"mc.correct_{label_column} IS NOT NULL AND mc.model_predicted_{label_column} IS DISTINCT FROM mc.correct_{label_column}"
37
- corrcond = f"mc.correct_{label_column} IS NOT NULL AND mc.model_predicted_{label_column} = mc.correct_{label_column}"
38
 
39
  # SQL to fetch misclassified records
40
  sql_mis = text(f"""
@@ -56,16 +53,17 @@ def fetch_misclassified_dataframe(label_column: str,
56
  n_mis = len(df_mis)
57
  n_correct = int(n_mis * correct_ratio)
58
 
59
- # SQL to fetch correct records
60
  sql_corr = text(f"""
61
- SELECT c.message AS grievance,
62
- mc.correct_department AS department,
63
- mc.correct_urgency AS urgency
64
- FROM misclassified_complaints mc
65
- JOIN complaints c ON c.id = mc.complaint_id
66
- WHERE mc.reviewed = TRUE
67
- AND {corrcond}
68
  """)
 
69
  with engine.connect() as conn:
70
  df_corr_all = pd.read_sql(sql_corr, conn)
71
 
@@ -83,12 +81,13 @@ def fetch_misclassified_dataframe(label_column: str,
83
 
84
  return df_combined
85
 
86
- # If this file is run directly, simple test:
87
- if __name__ == "__main__":
88
- # Quick sanity test for department label
89
- df_test = fetch_misclassified_dataframe(label_column="department", correct_ratio=0.5)
90
- print("Rows fetched:", len(df_test))
91
- print(df_test.head())
92
- # Basic assertion: if rows>0 then none of grievances should be null
93
- if len(df_test) > 0:
94
- assert df_test['grievance'].isna().sum() == 0, "Some grievances are null"
 
 
1
  # prepare_pd_dataframe.py
2
 
 
 
3
  import pandas as pd
4
+ from sqlalchemy import text
5
 
6
 
7
  def fetch_misclassified_dataframe(label_column: str,
 
32
 
33
  # define conditions based on column
34
  miscond = f"mc.correct_{label_column} IS NOT NULL AND mc.model_predicted_{label_column} IS DISTINCT FROM mc.correct_{label_column}"
 
35
 
36
  # SQL to fetch misclassified records
37
  sql_mis = text(f"""
 
53
  n_mis = len(df_mis)
54
  n_correct = int(n_mis * correct_ratio)
55
 
56
+ # SQL to fetch correct records from complaints table NOT in misclassified_complaints
57
  sql_corr = text(f"""
58
+ SELECT c.id AS complaint_id,
59
+ c.message AS grievance,
60
+ c.department AS department,
61
+ c.urgency AS urgency
62
+ FROM complaints c
63
+ WHERE c.id NOT IN (SELECT complaint_id FROM misclassified_complaints)
64
+ AND c.{label_column} IS NOT NULL
65
  """)
66
+
67
  with engine.connect() as conn:
68
  df_corr_all = pd.read_sql(sql_corr, conn)
69
 
 
81
 
82
  return df_combined
83
 
84
+ # # If this file is run directly, simple test:
85
+ # if __name__ == "__main__":
86
+ # # Quick sanity test for department label
87
+ # df_test = fetch_misclassified_dataframe(label_column="department",
88
+ # correct_ratio=0.5)
89
+ # print("Rows fetched:", len(df_test))
90
+ # print(df_test.head())
91
+ # # Basic assertion: if rows>0 then none of grievances should be null
92
+ # if len(df_test) > 0:
93
+ # assert df_test['grievance'].isna().sum() == 0, "Some grievances are null"