RAYAuser commited on
Commit
d2c7992
·
verified ·
1 Parent(s): a1b4245

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +91 -0
README.md ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ pipeline_tag: image-classification
4
+ ---
5
+ MNIST Classification Comparison Project
6
+ This repository hosts two classification models trained on the MNIST dataset, representing handwritten digits. The project's goal is to compare the performance of a deep learning model with a traditional machine learning model.
7
+
8
+ Included Models
9
+ The repository contains the following two models, saved after a parallel training process:
10
+
11
+ Simple Neural Network (PyTorch): A Feed-Forward Neural Network (FFNN) model suitable for low-resolution image classification tasks. It was trained to recognize digits from 0 to 9.
12
+
13
+ Random Forest (Scikit-learn): An ensemble model from the decision tree family, known for its robustness and high efficiency on structured data.
14
+
15
+ Usage Instructions
16
+ To use the models, please follow the steps below.
17
+
18
+ Prerequisites
19
+ The following Python libraries are required to load and run the models:
20
+
21
+ pip install torch scikit-learn joblib huggingface-hub
22
+
23
+
24
+ Loading the Models
25
+ The following script allows you to download the model files from this repository and load them into memory for later use.
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+ from joblib import load
30
+ from huggingface_hub import hf_hub_download
31
+
32
+ # Definition of the Neural Network class for loading.
33
+ class SimpleFFNN(nn.Module):
34
+ def __init__(self):
35
+ super(SimpleFFNN, self).__init__()
36
+ self.fc1 = nn.Linear(28 * 28, 128)
37
+ self.relu = nn.ReLU()
38
+ self.fc2 = nn.Linear(128, 10)
39
+
40
+ def forward(self, x):
41
+ x = x.view(-1, 28 * 28)
42
+ x = self.fc1(x)
43
+ x = self.relu(x)
44
+ x = self.fc2(x)
45
+ return x
46
+
47
+ # Hugging Face repository ID
48
+ repo_id = "RAYAuser/ratron-minst-2tech"
49
+
50
+ # Downloading the files
51
+ ffnn_path = hf_hub_download(repo_id=repo_id, filename="ffnn_model_state.pt")
52
+ rf_path = hf_hub_download(repo_id=repo_id, filename="random_forest_model.joblib")
53
+
54
+ # Loading the models
55
+ ffnn_model_loaded = SimpleFFNN()
56
+ ffnn_model_loaded.load_state_dict(torch.load(ffnn_path))
57
+ ffnn_model_loaded.eval()
58
+
59
+ rf_model_loaded = load(rf_path)
60
+
61
+ print("The models have been loaded successfully.")
62
+
63
+ Prediction on New Data
64
+ Once the models are loaded, you can use them to make inferences on new images.
65
+
66
+ import numpy as np
67
+
68
+ # Creating a numpy array to simulate an image (28x28 pixels)
69
+ sample_image = np.random.rand(28, 28)
70
+
71
+ # Reshaping the input data for the models
72
+ sample_image_flat = sample_image.reshape(1, -1)
73
+ ffnn_input_tensor = torch.from_numpy(sample_image_flat).float()
74
+
75
+ # Prediction with the Neural Network
76
+ with torch.no_grad():
77
+ output = ffnn_model_loaded(ffnn_input_tensor)
78
+ _, ffnn_prediction = torch.max(output.data, 1)
79
+
80
+ # Prediction with the Random Forest
81
+ rf_prediction = rf_model_loaded.predict(sample_image_flat)
82
+
83
+ print(f"Prediction by the Neural Network: {ffnn_prediction.item()}")
84
+ print(f"Prediction by the Random Forest: {rf_prediction[0]}")
85
+
86
+ Notes
87
+ The SimpleFFNN class must be defined to allow the PyTorch model to be loaded.
88
+
89
+ A warning regarding Scikit-learn version incompatibility may appear if the version used for training is not identical to the one in your environment. This is generally non-critical.
90
+
91
+ RAY AUTRA TECHNOLOGY 2025