File size: 1,628 Bytes
f145bd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# 机器学习 - 猫狗识别简单实现

这是一个简单的基于 Tensorflow 的猫狗识别模型。用于入门机器学习以及了解机器学习中的基础概念。

## 安装

首先需要下载猫狗识别的数据集用于训练。可以参考在 HuggingFace 的 [Microsoft 的数据集](https://huggingface.co/datasets/microsoft/cats_vs_dogs)
下载训练数据集后,需要做一些整理,先确保项目目录结构如下:

```
train/
├── cats/
│   ├── cat.0.jpg
│   ├── ....
|── dogs/
    ├── dog.0.jpg
    ├── ....
```

然后复制本项目中的 `.env.example` 一份到本目录,重命名为 `.env` 并修改内容,改成 train 目录的所在地

```bash
TRAIN_DATASET = /path/to/your/dataset
```

然后在 conda 环境下安装依赖。

```bash
conda create --name cat-vs-dog python=3.9 --file requirements.txt
```

## 训练

执行命令

```bash
python train.py
```

将会在本目录下生成 `cat_dog_model.h5` 模型文件。

## 预测

准备好测试数据集的目录,目录下只需存放图片即可。

在 `.env` 文件中添加测试数据集的目录。

```bash
TEST_DATASET = /path/to/your/dataset
```

请先确保你已经训练好了数据并正确生成了 `cat_dog_model.h5` 模型文件。

如果你没条件训练,可以直接在本仓库的 Release 上[下载现有的模型](https://github.com/PrinOrange/dog-vs-cat/releases/tag/1.0.0)。


然后执行命令

```bash
python test.py
```

在命令台中会产生如下输出:

```plaintext
The model predicts the image xxx.jpg is a cat, with sigmoid 0.9999812
...
```