# CycleGAN Image Style Transfer (Horse to Zebra)
This project implements an end-to-end CycleGAN model for unpaired image style transfer, specifically focused on the **Horse to Zebra** dataset.
## Project Structure
### TensorFlow Version (Recommended for this system)
- `tf_dataset.py`: TensorFlow Data loader.
- `tf_models.py`: Keras/TF CycleGAN architectures.
- `tf_train.py`: TensorFlow training script.
- `tf_predict.py`: TensorFlow inference script.
### PyTorch Version
- `dataset.py`: PyTorch Dataset class.
- `models.py`: PyTorch Generator and Discriminator.
- `train.py`: PyTorch training script.
- `predict.py`: PyTorch inference script.
- `download_data.py`: Script to download and extract the Horse2Zebra dataset.
- `requirements.txt`: Project dependencies.
## Setup
1. Install dependencies:
```bash
pip install -r requirements.txt
```
2. Download the dataset:
```bash
python download_data.py
```
## Training
### TensorFlow
```bash
python tf_train.py
```
### PyTorch
```bash
python train.py
```
Checkpoints and sample results will be saved in the `saved_images/` directory or as `.h5` files.
## Inference
### TensorFlow
```bash
python tf_predict.py
```
### PyTorch
```bash
python predict.py
```
The result is saved as `tf_prediction.png` (TensorFlow) or `prediction_result.png` (PyTorch).
## Web Application
A web interface is included for interacting with the models from a browser.
### Features
- **Bidirectional Style Transfer**: Switch between Horse ➔ Zebra and Zebra ➔ Horse.
- **Glassmorphic UI**: Modern, responsive design with smooth animations.
- **Real-time Preview**: See your uploaded image and stylized result side-by-side.
- **One-click Download**: Save your stylized art instantly.
### Running the App
1. Start the Flask server:
```bash
python app.py
```
2. Open your browser and go to `http://localhost:5000`.
## Notes
- The model uses **PatchGAN** for the discriminator and a **ResNet-based generator** with 9 residual blocks for 256x256 images.
- Training runs on both GPU and CPU.
- The identity loss weight is currently set to 0 to speed up training. It can be adjusted in the training scripts (via `LAMBDA_IDENTITY` or `identity_loss`).
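
To illustrate how the identity weight enters the generator objective, here is a minimal sketch in plain Python. It is not the code from `train.py` or `tf_train.py`: the function names, the `lambda_cycle` parameter, and its default of 10 are assumptions made for illustration, and the adversarial, cycle, and identity terms are taken as already computed scalars.

```python
def l1_distance(a, b):
    """Mean absolute difference between two equally sized flat pixel sequences."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def generator_loss(adversarial, cycle, identity,
                   lambda_cycle=10.0, lambda_identity=0.0):
    """Weighted sum of the generator loss terms.

    lambda_cycle is a hypothetical name and default; lambda_identity plays the
    role of LAMBDA_IDENTITY and defaults to 0, as described in the note above.
    """
    return adversarial + lambda_cycle * cycle + lambda_identity * identity


# Identity term: feed a zebra image to the horse-to-zebra generator and
# compare the output with the input (toy four-pixel "images").
real_zebra = [0.2, 0.8, 0.5, 0.1]
same_zebra = [0.3, 0.7, 0.5, 0.2]  # stand-in for the generator output
identity_term = l1_distance(real_zebra, same_zebra)

# With the default weight of 0 the identity term has no effect on the total.
without_identity = generator_loss(1.0, 0.5, identity_term)
with_identity = generator_loss(1.0, 0.5, identity_term, lambda_identity=5.0)
```

With a weight of 0, the identity term contributes nothing to the total, so a training script can skip the extra generator forward passes it requires. A nonzero weight penalizes the generator for altering images that already belong to the target domain.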
## Troubleshooting
- **PyTorch DLL Error (WinError 1114)**: On Windows, this error is often related to GPU driver conflicts or power management settings. If you encounter it, use the **TensorFlow version** provided in this repository, which is confirmed to be stable in this environment.
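
A quick way to check which framework loads on your machine before choosing a script is sketched below. The helper `pick_backend` and the `SCRIPTS` mapping are hypothetical and not part of this repository. The sketch relies only on the fact that a failed DLL initialization on Windows is raised as an `OSError` at import time.

```python
import importlib

# Training script for each framework, as listed under Project Structure.
SCRIPTS = {"torch": "train.py", "tensorflow": "tf_train.py"}


def pick_backend(candidates=("torch", "tensorflow")):
    """Return the name of the first module that imports cleanly, or None."""
    for name in candidates:
        try:
            importlib.import_module(name)
        except (ImportError, OSError) as exc:
            # A missing package raises ImportError; a DLL load failure such as
            # WinError 1114 raises OSError.
            print(f"{name} unavailable: {exc}")
            continue
        return name
    return None
```

Calling `pick_backend()` returns `"torch"`, `"tensorflow"`, or `None`, and `SCRIPTS.get(...)` maps the result to the matching training script.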