zhaospei commited on
Commit
493aea3
·
verified ·
1 Parent(s): 2b9b405

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +75 -3
  2. gitattributes +35 -0
  3. model.safetensors +3 -0
README.md CHANGED
@@ -1,3 +1,75 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🏠 Mô hình Wide & Deep Neural Network - Dự đoán Giá Nhà California
2
+
3
+ ## 📝 Mô tả
4
+
5
+ Đây là một mô hình **Wide & Deep Neural Network** được huấn luyện trên tập dữ liệu **California Housing** để dự đoán giá nhà trung bình (`MedHouseVal`). Mô hình được xây dựng bằng **PyTorch**, dựa trên kiến trúc trong cuốn *Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow* của **Aurélien Géron**.
6
+
7
+ ## 📌 Nhiệm vụ
8
+
9
+ Dự đoán giá nhà dựa trên dữ liệu bảng (tabular regression) với 8 đặc trưng đầu vào.
10
+
11
+ ## 📥 Đầu vào
12
+
13
+ - **Số chiều**: `[batch_size, 8]`
14
+ - **Kiểu dữ liệu**: `torch.FloatTensor`
15
+ - **Các đặc trưng đầu vào**:
16
+ - `'MedInc'` – Thu nhập trung vị
17
+ - `'HouseAge'` – Tuổi trung bình của căn nhà
18
+ - `'AveRooms'` – Số phòng trung bình
19
+ - `'AveBedrms'` – Số phòng ngủ trung bình
20
+ - `'Population'` – Dân số
21
+ - `'AveOccup'` – Số người trung bình trên mỗi hộ
22
+ - `'Latitude'` – Vĩ độ
23
+ - `'Longitude'` – Kinh độ
24
+
25
+ ## 📤 Đầu ra
26
+
27
+ - **Kiểu**: `torch.FloatTensor` có shape `[batch_size, 1]`
28
+ - **Ý nghĩa**: Giá nhà trung bình dự đoán (giá trị thực).
29
+
30
+ ## 🧪 Cách sử dụng mô hình
31
+
32
+ Dưới đây là ví dụ về cách sử dụng mô hình với dữ liệu đầu vào giả lập:
33
+
34
+ ```python
35
+ import torch
36
+ import torch.nn as nn
37
+ from huggingface_hub import PyTorchModelHubMixin
38
+
39
+ # Tạo dữ liệu đầu vào giả lập (batch 1, 8 features)
40
+ x_input = torch.randn(1, 8)
41
+ print("Mock input:")
42
+ print(x_input)
43
+
44
+ # Định nghĩa mô hình Wide & Deep Neural Network
45
+ class WideAndDeepNet(nn.Module, PyTorchModelHubMixin):
46
+ def __init__(self):
47
+ super().__init__()
48
+ self.hidden1 = nn.Linear(6, 30)
49
+ self.hidden2 = nn.Linear(30, 30)
50
+ self.main_head = nn.Linear(35, 1)
51
+ self.aux_head = nn.Linear(30, 1)
52
+ self.main_loss_fn = nn.MSELoss(reduction='sum')
53
+ self.aux_loss_fn = nn.MSELoss(reduction='sum')
54
+
55
+ def forward(self, input_wide, input_deep, label=None):
56
+ act = torch.relu(self.hidden1(input_deep))
57
+ act = torch.relu(self.hidden2(act))
58
+ concat = torch.cat([input_wide, act], dim=1)
59
+ main_output = self.main_head(concat)
60
+ aux_output = self.aux_head(act)
61
+ if label is not None:
62
+ main_loss = self.main_loss_fn(main_output.squeeze(), label)
63
+ aux_loss = self.aux_loss_fn(aux_output.squeeze(), label)
64
+ return WideAndDeepNetOutput(main_output=main_output, aux_output=aux_output)
65
+
66
+ # Tải mô hình từ Hugging Face Hub
67
+ model = WideAndDeepNet.from_pretrained("sadhaklal/wide-and-deep-net-california-housing-v3")
68
+ model.eval()
69
+
70
+ # Dự đoán với mô hình
71
+ with torch.no_grad():
72
+ prediction = model(x_input)
73
+
74
+ print(f"Giá nhà dự đoán (mock input): {prediction.item():.3f}")
75
+ ```
gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78354917b85970dac203b2de54d5e84e1ff4e62aa45b4a80052d9b532a7fe049
3
+ size 5396