Fredaaaaaa commited on
Commit
747baac
·
verified ·
1 Parent(s): 2d943aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -15
app.py CHANGED
@@ -27,27 +27,23 @@ dataset_path = hf_hub_download(repo_id="Fredaaaaaa/hybrid_model", filename="labe
27
  df = pd.read_csv(dataset_path, encoding='ISO-8859-1')
28
  print(f"Dataset loaded successfully! Shape: {df.shape}")
29
 
30
- # Create a set of all unique drugs in the dataset for validation
31
- all_drugs = set()
 
32
 
33
- # Check which columns contain drug names
34
- drug_columns = []
35
- for col in df.columns:
36
- if 'drug' in col.lower() or 'medication' in col.lower():
37
- drug_columns.append(col)
38
- # Add all drugs from this column to our set after cleaning
39
- clean_drugs = df[col].dropna().astype(str).apply(lambda x: x.strip().lower())
40
- all_drugs.update(clean_drugs.unique())
41
 
42
  # Calculate class weights to handle imbalanced classes
43
-
44
-
45
- # Correct the 'classes' parameter to be a numpy.ndarray
46
- class_weights = compute_class_weight('balanced', classes=np.array([0, 1, 2, 3]), y=df['severity'])
47
-
48
  class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
49
  loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)
50
 
 
 
 
51
  # Function to properly clean drug names
52
  def clean_drug_name(drug_name):
53
  if not drug_name:
 
27
  df = pd.read_csv(dataset_path, encoding='ISO-8859-1')
28
  print(f"Dataset loaded successfully! Shape: {df.shape}")
29
 
30
+ # Check the columns and display first few rows for debugging
31
+ print(df.columns)
32
+ print(df.head())
33
 
34
+ # Get unique severity classes from the dataset
35
+ unique_classes = df['severity'].unique()
36
+ print(f"Unique severity classes in dataset: {unique_classes}")
 
 
 
 
 
37
 
38
  # Calculate class weights to handle imbalanced classes
39
+ # Use the unique classes from the dataset for the `classes` parameter
40
+ class_weights = compute_class_weight('balanced', classes=np.unique(unique_classes), y=df['severity'])
 
 
 
41
  class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
42
  loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)
43
 
44
+ # The rest of your code follows here...
45
+
46
+
47
  # Function to properly clean drug names
48
  def clean_drug_name(drug_name):
49
  if not drug_name: