coldstartmodel_test / preprocess_test.py
datasciencesage's picture
Update preprocess_test.py
4cd4b81 verified
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder,StandardScaler
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
class Model(nn.Module):
    """Fully connected classifier: three Linear -> BatchNorm1d -> ReLU stages, then a linear head.

    Hidden widths are 1024, 512 and 256. The head emits one raw score
    (logit) per class; no softmax is applied here.

    Args:
        input_shape: Number of input features per sample.
        num_classes: Number of output scores per sample.
    """

    def __init__(self, input_shape, num_classes):
        super().__init__()
        # Attribute names and creation order are kept as-is: they determine
        # the state-dict keys and the order in which parameters are initialised.
        self.fc1 = nn.Linear(input_shape, 1024)
        self.bn1 = nn.BatchNorm1d(1024)
        self.fc2 = nn.Linear(1024, 512)
        self.bn2 = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, 256)
        self.bn3 = nn.BatchNorm1d(256)
        self.fc4 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Run the three hidden stages and return the output of the final linear layer."""
        hidden_stages = (
            (self.fc1, self.bn1),
            (self.fc2, self.bn2),
            (self.fc3, self.bn3),
        )
        activations = x
        for dense, norm in hidden_stages:
            activations = F.relu(norm(dense(activations)))
        return self.fc4(activations)
class Preprocess_Test:
    """Clean an employee DataFrame and score it with the model hosted on Hugging Face.

    Intended call order: construct with a DataFrame, call ``preprocess()``,
    then ``test()``. The DataFrame passed in is modified in place.
    """

    # Column names the pipeline works with once renaming has happened.
    _TARGET_COLUMNS = ("empid", "hourly_pay", "job", "pincode", "rating")
    # Raw dataset column name -> pipeline column name.
    _RENAME_MAP = {
        "EmpID": "empid",
        "PayZone": "hourly_pay",
        "JobFunctionDescription": "job",
        "LocationCode": "pincode",
        "Current Employee Rating": "rating",
    }

    def __init__(self, df):
        """Store the DataFrame and pick the torch device.

        Args:
            df: The pandas DataFrame to preprocess and evaluate. It is kept
                by reference, so later steps mutate the caller's object.
        """
        self.df = df
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print("INSIDE CLEANING GOT THE DATASET")
        # Expose Model on __main__ so unpickling in test() can resolve a
        # "__main__.Model" reference (presumably the checkpoint was saved
        # from a script where Model lived in __main__ -- confirm).
        import __main__
        __main__.Model = Model

    def delete_redundant(self, percent):
        """Drop, in place, every column with too many missing values.

        Args:
            percent: Threshold as a percentage of the row count. A column is
                dropped when its null count is strictly greater than
                ``int(len(df) * percent / 100)``.
        """
        max_nulls = int(len(self.df) * (percent / 100))
        null_counts = self.df.isnull().sum()
        sparse_columns = null_counts[null_counts > max_nulls].index.tolist()
        self.df.drop(columns=sparse_columns, inplace=True)

    def delete_unncecessary(self):
        """Reduce the frame to the five pipeline columns, renaming the raw ones.

        If all of ``empid``, ``hourly_pay``, ``job``, ``pincode`` and
        ``rating`` are already present, nothing is changed. Otherwise every
        column that is not one of the raw source columns is dropped and the
        raw columns are renamed to the pipeline names.
        """
        if all(col in self.df.columns for col in self._TARGET_COLUMNS):
            return
        # One drop call instead of dropping column-by-column while iterating.
        unwanted = [col for col in self.df.columns if col not in self._RENAME_MAP]
        self.df.drop(columns=unwanted, inplace=True)
        self.df.rename(columns=self._RENAME_MAP, inplace=True)

    def preprocess(self, percent=30):
        """Clean the frame, label-encode non-numeric columns and build the test arrays.

        Sets ``self.X_test`` (standardised features), ``self.Y_test``
        (label-encoded ``empid``) and ``self.label_mappings`` (per-column
        mapping from original value to encoded integer).

        Args:
            percent: Missing-value threshold forwarded to ``delete_redundant``.

        Raises:
            KeyError: If no ``empid`` column exists after cleaning.
        """
        self.delete_redundant(percent=percent)
        self.delete_unncecessary()

        label_mappings = {}
        for col in self.df.select_dtypes(exclude=np.number).columns:
            encoder = LabelEncoder()
            self.df[col] = encoder.fit_transform(self.df[col])
            # classes_ is sorted, so a class's position is its encoded value.
            label_mappings[col] = {
                label: code for code, label in enumerate(encoder.classes_)
            }
        # Previously built and thrown away; kept so callers can decode values.
        self.label_mappings = label_mappings

        features = self.df.drop(columns="empid").to_numpy()
        targets = self.df["empid"].to_numpy()

        # NOTE(review): the scaler and the label encoder are fit on the data
        # being scored. If training used its own fitted transformers, the
        # features/labels here may not line up with what the model saw --
        # confirm against the training pipeline.
        self.X_test = StandardScaler().fit_transform(features)
        self.Y_test = LabelEncoder().fit_transform(targets)

    @staticmethod
    def _resolve_model(loaded_data):
        """Turn whatever ``torch.load`` returned into a ready ``nn.Module``.

        Accepts a pickled module, a dict holding the module under ``"model"``,
        or a bare state dict for ``Model`` (layer sizes are read from the
        ``fc1.weight`` and ``fc4.weight`` tensors).

        Raises:
            ValueError: If a state dict lacks the keys needed to size ``Model``.
            TypeError: If the payload is neither a module nor a dict.
        """
        if isinstance(loaded_data, nn.Module):
            return loaded_data
        if isinstance(loaded_data, dict):
            if "model" in loaded_data:
                return loaded_data["model"]
            required = ("fc1.weight", "fc4.weight")
            missing = [key for key in required if key not in loaded_data]
            if missing:
                raise ValueError(
                    f"Checkpoint is not a Model state_dict; missing keys: {missing}"
                )
            model = Model(
                input_shape=loaded_data["fc1.weight"].shape[1],
                num_classes=loaded_data["fc4.weight"].shape[0],
            )
            model.load_state_dict(loaded_data)
            return model
        raise TypeError(
            f"Unsupported checkpoint payload of type {type(loaded_data).__name__}"
        )

    def test(self):
        """Download the model, run it over the preprocessed data and return predictions.

        Also prints the first ten predicted/actual labels and the overall
        accuracy, and stores the accuracy on ``self.accuracy`` (``None`` when
        there were no samples).

        Returns:
            ``{"predictions": [...]}`` with one predicted class index per row.

        Raises:
            AttributeError: If ``preprocess()`` has not been called first.
        """
        # Fail before the download if there is nothing to score.
        if not hasattr(self, "X_test") or not hasattr(self, "Y_test"):
            raise AttributeError("preprocess() must be called before test()")

        print(f"Using device: {self.device}")

        repo_id = "Haliyka/coldstartmodel"
        model_file = "model_full.pth"
        local_path = hf_hub_download(repo_id=repo_id, filename=model_file)

        # SECURITY: weights_only=False unpickles arbitrary objects from the
        # downloaded file; only use with a repository you trust.
        loaded_data = torch.load(
            local_path, map_location=self.device, weights_only=False
        )
        model_loaded = self._resolve_model(loaded_data)
        model_loaded.to(self.device)
        model_loaded.eval()  # use BatchNorm running statistics, not batch stats
        print(f"Model loaded from Hugging Face: {repo_id}")

        X_test_t = torch.tensor(self.X_test, dtype=torch.float32)
        Y_test_t = torch.tensor(self.Y_test, dtype=torch.long)

        BATCH_SIZE = 256
        correct = 0
        total = 0
        all_predictions = []

        with torch.no_grad():
            for start in range(0, len(X_test_t), BATCH_SIZE):
                batch_x = X_test_t[start:start + BATCH_SIZE].to(self.device)
                batch_y = Y_test_t[start:start + BATCH_SIZE].to(self.device)

                outputs = model_loaded(batch_x)
                predicted = torch.argmax(outputs, dim=1)

                total += batch_y.size(0)
                correct += (predicted == batch_y).sum().item()
                all_predictions.extend(predicted.cpu().numpy().tolist())

                if start == 0:
                    print(f"First 10 Test batch results - Predicted: {predicted.cpu().numpy()[:10]}")
                    print(f"First 10 Test batch results - Actual: {batch_y.cpu().numpy()[:10]}")

        # The counters used to be accumulated and then discarded.
        self.accuracy = correct / total if total else None
        if self.accuracy is not None:
            print(f"Test accuracy: {self.accuracy:.4f}")

        return {
            "predictions": all_predictions}