Upload folder using huggingface_hub
Browse files- .dockerignore +38 -0
- .env.example +11 -0
- .gitignore +7 -0
- .python-version +1 -0
- Dockerfile +24 -0
- README.md +371 -0
- app/__init__.py +7 -0
- app/api/__init__.py +1 -0
- app/api/controllers.py +36 -0
- app/api/models.py +109 -0
- app/api/routes/__init__.py +1 -0
- app/api/routes/prediction.py +23 -0
- app/core/__init__.py +1 -0
- app/core/app.py +80 -0
- app/core/dependencies.py +17 -0
- app/core/logging.py +26 -0
- app/services/__init__.py +1 -0
- app/services/base.py +30 -0
- app/services/inference.py +92 -0
- app/services/inference_clip.py +171 -0
- app/services/inference_dino.py +177 -0
- cat.json +0 -0
- challenge-cli.py +310 -0
- delete_corrupted_files.py +55 -0
- main.py +6 -0
- makefile +58 -0
- mask.png +0 -0
- prompt.sh +6 -0
- requirements copy.txt +26 -0
- requirements.cpu.in +1 -0
- requirements.cpu.txt +561 -0
- requirements.in +34 -0
- requirements.torch.cpu.in +3 -0
- requirements.torch.cpu.txt +9 -0
- requirements.torch.gpu.in +2 -0
- requirements.torch.gpu.txt +71 -0
- requirements.txt +530 -0
- response.json +1 -0
- scripts/generate_test_datasets.py +411 -0
- scripts/model_download.bash +10 -0
- scripts/test_datasets.py +382 -0
- submit.sh +2 -0
- test_main.http +12 -0
- upload_model.py +93 -0
- upload_submission.sh +5 -0
- utils/My_version.py +207 -0
- utils/losses.py +440 -0
.dockerignore
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
*.pyc
|
| 3 |
+
*.pyo
|
| 4 |
+
*.pyd
|
| 5 |
+
.Python
|
| 6 |
+
*.so
|
| 7 |
+
*.egg
|
| 8 |
+
*.egg-info
|
| 9 |
+
dist
|
| 10 |
+
build
|
| 11 |
+
|
| 12 |
+
.venv
|
| 13 |
+
venv
|
| 14 |
+
ENV
|
| 15 |
+
env
|
| 16 |
+
|
| 17 |
+
.git
|
| 18 |
+
.gitignore
|
| 19 |
+
.idea
|
| 20 |
+
.vscode
|
| 21 |
+
.claude
|
| 22 |
+
|
| 23 |
+
*.md
|
| 24 |
+
README.md
|
| 25 |
+
Dockerfile
|
| 26 |
+
.dockerignore
|
| 27 |
+
|
| 28 |
+
test_*.http
|
| 29 |
+
test_results
|
| 30 |
+
scripts/test_datasets
|
| 31 |
+
|
| 32 |
+
.pytest_cache
|
| 33 |
+
.coverage
|
| 34 |
+
htmlcov
|
| 35 |
+
|
| 36 |
+
*.log
|
| 37 |
+
.DS_Store
|
| 38 |
+
.python-version
|
.env.example
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# App Configuration
|
| 2 |
+
APP_NAME="ML Inference Service"
|
| 3 |
+
APP_VERSION="0.1.0"
|
| 4 |
+
DEBUG=false
|
| 5 |
+
|
| 6 |
+
# Server
|
| 7 |
+
HOST="0.0.0.0"
|
| 8 |
+
PORT=8000
|
| 9 |
+
|
| 10 |
+
# Model
|
| 11 |
+
MODEL_NAME="microsoft/resnet-18"
|
.gitignore
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dyff-outputs/
|
| 2 |
+
models/
|
| 3 |
+
test_datasets/
|
| 4 |
+
test_results/
|
| 5 |
+
venv/
|
| 6 |
+
**/__pycache__
|
| 7 |
+
*.tmp
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.12.11
|
Dockerfile
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.12-slim

ENV PYTHONDONTWRITEBYTECODE="1" \
    PYTHONUNBUFFERED="1"

# hadolint ignore=DL3013
RUN python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel

WORKDIR /app/

# Install general deps first, torch second, so layer caching is effective.
COPY requirements.cpu.txt ./
RUN python3 -m pip install --no-cache-dir -r ./requirements.cpu.txt

COPY requirements.torch.cpu.txt ./
RUN python3 -m pip install --no-cache-dir -r ./requirements.torch.cpu.txt

COPY app ./app
COPY models ./models
COPY main.py ./

# Fix: the ENTRYPOINT binds port 7860, but this previously declared 8000.
# Keep EXPOSE in sync with the port uvicorn actually listens on.
EXPOSE 7860

# ENTRYPOINT ["python3", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
ENTRYPOINT ["python3", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
ADDED
|
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
---
|
| 4 |
+
|
| 5 |
+
# ML Inference Service
|
| 6 |
+
|
| 7 |
+
FastAPI service for serving ML models over HTTP. Comes with ResNet-18 for image classification out of the box, but you can swap in any model you want.
|
| 8 |
+
|
| 9 |
+
## Quick Start
|
| 10 |
+
|
| 11 |
+
**Install `uv`:**
|
| 12 |
+
https://docs.astral.sh/uv/getting-started/installation/
|
| 13 |
+
|
| 14 |
+
**Local development:**
|
| 15 |
+
```bash
|
| 16 |
+
# Install dependencies
|
| 17 |
+
make setup
|
| 18 |
+
source venv/bin/activate
|
| 19 |
+
|
| 20 |
+
# Download the example model
|
| 21 |
+
make download
|
| 22 |
+
|
| 23 |
+
# Run it
|
| 24 |
+
make serve
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
In a second terminal:
|
| 28 |
+
```bash
|
| 29 |
+
# Process an example input
|
| 30 |
+
./prompt.sh cat.json
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
Server runs on `http://127.0.0.1:8000`. Check `/docs` for the interactive API documentation.
|
| 34 |
+
|
| 35 |
+
**Docker:**
|
| 36 |
+
```bash
|
| 37 |
+
# Build
|
| 38 |
+
make docker-build
|
| 39 |
+
|
| 40 |
+
# Run
|
| 41 |
+
make docker-run
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
## Testing the API
|
| 45 |
+
|
| 46 |
+
```bash
|
| 47 |
+
# Using curl
|
| 48 |
+
curl -X POST http://localhost:8000/predict \
|
| 49 |
+
-H "Content-Type: application/json" \
|
| 50 |
+
-d '{
|
| 51 |
+
"image": {
|
| 52 |
+
"mediaType": "image/jpeg",
|
| 53 |
+
"data": "<base64-encoded-image>"
|
| 54 |
+
}
|
| 55 |
+
}'
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
Example response:
|
| 59 |
+
```json
|
| 60 |
+
{
|
| 61 |
+
"logprobs": [-0.859380304813385,-1.2701971530914307,-2.1918208599090576,-1.69235098361969],
|
| 62 |
+
"localizationMask": {
|
| 63 |
+
"mediaType":"image/png",
|
| 64 |
+
"data":"iVBORw0KGgoAAAANSUhEUgAAA8AAAAKDAQAAAAD9Fl5AAAAAu0lEQVR4nO3NsREAMAgDMWD/nZMVKEwn1T5/FQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAMCl3g5f+HC24TRhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAj70gwKsTlmdBwAAAABJRU5ErkJggg=="
|
| 65 |
+
}
|
| 66 |
+
}
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
## Project Structure
|
| 70 |
+
|
| 71 |
+
```
|
| 72 |
+
example-submission/
|
| 73 |
+
├── main.py # Entry point
|
| 74 |
+
├── app/
|
| 75 |
+
│ ├── core/
|
| 76 |
+
│ │ ├── app.py # <= INSTANTIATE YOUR DETECTOR HERE
|
| 77 |
+
│ │ └── logging.py # Logging setup
|
| 78 |
+
│ ├── api/
|
| 79 |
+
│ │ ├── models.py # Request/response schemas
|
| 80 |
+
│ │ ├── controllers.py # Business logic
|
| 81 |
+
│ │ └── routes/
|
| 82 |
+
│ │ └── prediction.py # POST /predict
|
| 83 |
+
│ └── services/
|
| 84 |
+
│ ├── base.py # <= YOUR DETECTOR IMPLEMENTS THIS INTERFACE
|
| 85 |
+
│ └── inference.py # Example service based on ResNet-18
|
| 86 |
+
├── models/
|
| 87 |
+
│ └── microsoft/
|
| 88 |
+
│ └── resnet-18/ # Model weights and config
|
| 89 |
+
├── scripts/
|
| 90 |
+
│ ├── model_download.bash
|
| 91 |
+
│ ├── generate_test_datasets.py
|
| 92 |
+
│ └── test_datasets.py
|
| 93 |
+
├── Dockerfile
|
| 94 |
+
├── .env.example # Environment config template
|
| 95 |
+
├── cat.json # An example /predict request object
|
| 96 |
+
├── makefile
|
| 97 |
+
├── prompt.sh # Script that makes a /predict request
|
| 98 |
+
├── requirements.in
|
| 99 |
+
├── requirements.txt
|
| 100 |
+
└── response.json # An example /predict response object
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
## How to Plug In Your Own Model
|
| 105 |
+
|
| 106 |
+
To integrate your model, implement the `InferenceService` abstract class defined in `app/services/base.py`. You can follow the example implementation in `app/services/inference.py`, which is based on ResNet-18. After implementing the required interface, instantiate your model in the `lifespan()` function in `app/core/app.py`, replacing the `ResNetInferenceService` instance.
|
| 107 |
+
|
| 108 |
+
### Step 1: Create Your Service Class
|
| 109 |
+
|
| 110 |
+
```python
|
| 111 |
+
# app/services/your_model_service.py
|
| 112 |
+
from app.services.base import InferenceService
|
| 113 |
+
from app.api.models import ImageRequest, PredictionResponse
|
| 114 |
+
|
| 115 |
+
class YourModelService(InferenceService[ImageRequest, PredictionResponse]):
|
| 116 |
+
def __init__(self, model_name: str):
|
| 117 |
+
self.model_name = model_name
|
| 118 |
+
self.model_path = f"models/{model_name}"
|
| 119 |
+
self.model = None
|
| 120 |
+
self._is_loaded = False
|
| 121 |
+
|
| 122 |
+
def load_model(self) -> None:
|
| 123 |
+
"""Load your model here. Called once at startup."""
|
| 124 |
+
self.model = load_your_model(self.model_path)
|
| 125 |
+
self._is_loaded = True
|
| 126 |
+
|
| 127 |
+
def predict(self, request: ImageRequest) -> PredictionResponse:
|
| 128 |
+
"""Actual inference happens here."""
|
| 129 |
+
image = decode_base64_image(request.image.data)
|
| 130 |
+
result = self.model(image)
|
| 131 |
+
|
| 132 |
+
logprobs = ...
|
| 133 |
+
mask = ...
|
| 134 |
+
|
| 135 |
+
return PredictionResponse(
|
| 136 |
+
logprobs=logprobs,
|
| 137 |
+
localizationMask=mask,
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
@property
|
| 141 |
+
def is_loaded(self) -> bool:
|
| 142 |
+
return self._is_loaded
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
### Step 2: Register Your Service
|
| 146 |
+
|
| 147 |
+
Open `app/core/app.py` and find the lifespan function:
|
| 148 |
+
|
| 149 |
+
```python
|
| 150 |
+
# Change this line:
|
| 151 |
+
service = ResNetInferenceService(model_name="microsoft/resnet-18")
|
| 152 |
+
|
| 153 |
+
# To this:
|
| 154 |
+
service = YourModelService(...)
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
That's it. The `/predict` endpoint now serves your model.
|
| 158 |
+
|
| 159 |
+
### Model Files
|
| 160 |
+
|
| 161 |
+
Put your model files under the `models/` directory:
|
| 162 |
+
|
| 163 |
+
```
|
| 164 |
+
models/
|
| 165 |
+
└── your-org/
|
| 166 |
+
└── your-model/
|
| 167 |
+
├── config.json
|
| 168 |
+
├── weights.bin
|
| 169 |
+
└── (other files)
|
| 170 |
+
```
|
| 171 |
+
|
| 172 |
+
## Configuration
|
| 173 |
+
|
| 174 |
+
Settings are managed via environment variables or a `.env` file. See `.env.example` for all available options.
|
| 175 |
+
|
| 176 |
+
**Default values:**
|
| 177 |
+
- `APP_NAME`: "ML Inference Service"
|
| 178 |
+
- `APP_VERSION`: "0.1.0"
|
| 179 |
+
- `DEBUG`: false
|
| 180 |
+
- `HOST`: "0.0.0.0"
|
| 181 |
+
- `PORT`: 8000
|
| 182 |
+
- `MODEL_NAME`: "microsoft/resnet-18"
|
| 183 |
+
|
| 184 |
+
**To customize:**
|
| 185 |
+
```bash
|
| 186 |
+
# Copy the example
|
| 187 |
+
cp .env.example .env
|
| 188 |
+
|
| 189 |
+
# Edit values
|
| 190 |
+
vim .env
|
| 191 |
+
```
|
| 192 |
+
|
| 193 |
+
Or set environment variables directly:
|
| 194 |
+
```bash
|
| 195 |
+
export MODEL_NAME="google/vit-base-patch16-224"
|
| 196 |
+
uvicorn main:app --reload
|
| 197 |
+
```
|
| 198 |
+
|
| 199 |
+
## Deployment
|
| 200 |
+
|
| 201 |
+
**Development:**
|
| 202 |
+
```bash
|
| 203 |
+
uvicorn main:app --reload
|
| 204 |
+
```
|
| 205 |
+
|
| 206 |
+
**Production:**
|
| 207 |
+
```bash
|
| 208 |
+
gunicorn main:app -w 4 -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:8000
|
| 209 |
+
```
|
| 210 |
+
|
| 211 |
+
The service runs on CPU by default. For GPU inference, install CUDA-enabled PyTorch and modify your service to move tensors to the GPU device.
|
| 212 |
+
|
| 213 |
+
**Docker:**
|
| 214 |
+
- Single-stage build on the slim `python:3.12-slim` base image
|
| 215 |
+
- Runs as the image's default (root) user — add a `USER` directive if you need a non-root container
|
| 216 |
+
- Python dependencies installed in user site-packages
|
| 217 |
+
- Model files baked into the image
|
| 218 |
+
|
| 219 |
+
## What Happens When You Start the Server
|
| 220 |
+
|
| 221 |
+
```
|
| 222 |
+
INFO: Starting ML Inference Service...
|
| 223 |
+
INFO: Initializing ResNet service: models/microsoft/resnet-18
|
| 224 |
+
INFO: Loading model from models/microsoft/resnet-18
|
| 225 |
+
INFO: Model loaded: 1000 classes
|
| 226 |
+
INFO: Startup completed successfully
|
| 227 |
+
INFO: Uvicorn running on http://0.0.0.0:8000
|
| 228 |
+
```
|
| 229 |
+
|
| 230 |
+
If you see "Model directory not found", check that your model files exist at the expected path with the full org/model structure.
|
| 231 |
+
|
| 232 |
+
## API Reference
|
| 233 |
+
|
| 234 |
+
**Endpoint:** `POST /predict`
|
| 235 |
+
|
| 236 |
+
**Request:**
|
| 237 |
+
```json
|
| 238 |
+
{
|
| 239 |
+
"image": {
|
| 240 |
+
"mediaType": "image/jpeg", // or "image/png"
|
| 241 |
+
"data": "<base64 string>"
|
| 242 |
+
}
|
| 243 |
+
}
|
| 244 |
+
```
|
| 245 |
+
|
| 246 |
+
**Response:**
|
| 247 |
+
```json
|
| 248 |
+
{
|
| 249 |
+
"logprobs": [float], // Log-probabilities of each label
|
| 250 |
+
"localizationMask": { // [Optional] binary mask
|
| 251 |
+
"mediaType": "image/png", // Always png
|
| 252 |
+
"data": "<base64 string>" // Image data
|
| 253 |
+
}
|
| 254 |
+
}
|
| 255 |
+
```
|
| 256 |
+
|
| 257 |
+
**Docs:**
|
| 258 |
+
- Swagger UI: `http://localhost:8000/docs`
|
| 259 |
+
- ReDoc: `http://localhost:8000/redoc`
|
| 260 |
+
- OpenAPI JSON: `http://localhost:8000/openapi.json`
|
| 261 |
+
|
| 262 |
+
## PyArrow Test Datasets
|
| 263 |
+
|
| 264 |
+
We've included a test dataset system for validating your model. It generates 100 standardized test cases covering normal inputs, edge cases, performance benchmarks, and model comparisons.
|
| 265 |
+
|
| 266 |
+
### Generate Datasets
|
| 267 |
+
|
| 268 |
+
```bash
|
| 269 |
+
python scripts/generate_test_datasets.py
|
| 270 |
+
```
|
| 271 |
+
|
| 272 |
+
This creates:
|
| 273 |
+
- `scripts/test_datasets/*.parquet` - Test data (images, requests, expected responses)
|
| 274 |
+
- `scripts/test_datasets/*_metadata.json` - Human-readable descriptions
|
| 275 |
+
- `scripts/test_datasets/datasets_summary.json` - Overview of all datasets
|
| 276 |
+
|
| 277 |
+
### Run Tests
|
| 278 |
+
|
| 279 |
+
```bash
|
| 280 |
+
# Start your service first
|
| 281 |
+
make serve
|
| 282 |
+
```
|
| 283 |
+
|
| 284 |
+
In another terminal:
|
| 285 |
+
|
| 286 |
+
```bash
|
| 287 |
+
# Quick test (5 samples per dataset)
|
| 288 |
+
python scripts/test_datasets.py --quick
|
| 289 |
+
|
| 290 |
+
# Full validation
|
| 291 |
+
python scripts/test_datasets.py
|
| 292 |
+
|
| 293 |
+
# Test specific category
|
| 294 |
+
python scripts/test_datasets.py --category edge_case
|
| 295 |
+
```
|
| 296 |
+
|
| 297 |
+
### Dataset Categories (25 datasets each)
|
| 298 |
+
|
| 299 |
+
**1. Standard Tests** (`standard_test_*.parquet`)
|
| 300 |
+
- Normal images: random patterns, shapes, gradients
|
| 301 |
+
- Common sizes: 224x224, 256x256, 299x299, 384x384
|
| 302 |
+
- Formats: JPEG, PNG
|
| 303 |
+
- Purpose: Baseline validation
|
| 304 |
+
|
| 305 |
+
**2. Edge Cases** (`edge_case_*.parquet`)
|
| 306 |
+
- Tiny images (32x32, 1x1)
|
| 307 |
+
- Huge images (2048x2048)
|
| 308 |
+
- Extreme aspect ratios (1000x50)
|
| 309 |
+
- Corrupted data, malformed requests
|
| 310 |
+
- Purpose: Test error handling
|
| 311 |
+
|
| 312 |
+
**3. Performance Benchmarks** (`performance_test_*.parquet`)
|
| 313 |
+
- Batch sizes: 1, 5, 10, 25, 50, 100 images
|
| 314 |
+
- Latency and throughput tracking
|
| 315 |
+
- Purpose: Performance profiling
|
| 316 |
+
|
| 317 |
+
**4. Model Comparisons** (`model_comparison_*.parquet`)
|
| 318 |
+
- Same inputs across different architectures
|
| 319 |
+
- Models: ResNet-18/50, ViT, ConvNext, Swin
|
| 320 |
+
- Purpose: Cross-model benchmarking
|
| 321 |
+
|
| 322 |
+
### Test Output
|
| 323 |
+
|
| 324 |
+
```
|
| 325 |
+
DATASET TESTING SUMMARY
|
| 326 |
+
============================================================
|
| 327 |
+
Datasets tested: 100
|
| 328 |
+
Successful datasets: 95
|
| 329 |
+
Failed datasets: 5
|
| 330 |
+
Total samples: 1,247
|
| 331 |
+
Overall success rate: 87.3%
|
| 332 |
+
Test duration: 45.2s
|
| 333 |
+
|
| 334 |
+
Performance:
|
| 335 |
+
Avg latency: 123.4ms
|
| 336 |
+
Median latency: 98.7ms
|
| 337 |
+
p95 latency: 342.1ms
|
| 338 |
+
Max latency: 2,341.0ms
|
| 339 |
+
Requests/sec: 27.6
|
| 340 |
+
|
| 341 |
+
Category breakdown:
|
| 342 |
+
standard: 25 datasets, 94.2% avg success
|
| 343 |
+
edge_case: 25 datasets, 76.8% avg success
|
| 344 |
+
performance: 25 datasets, 91.1% avg success
|
| 345 |
+
model_comparison: 25 datasets, 89.3% avg success
|
| 346 |
+
```
|
| 347 |
+
|
| 348 |
+
## Common Issues
|
| 349 |
+
|
| 350 |
+
**Port 8000 already in use:**
|
| 351 |
+
```bash
|
| 352 |
+
# Find what's using it
|
| 353 |
+
lsof -i :8000
|
| 354 |
+
|
| 355 |
+
# Or just use a different port
|
| 356 |
+
uvicorn main:app --port 8080
|
| 357 |
+
```
|
| 358 |
+
|
| 359 |
+
**Model not loading:**
|
| 360 |
+
- Check the path: models should be in `models/<org>/<model-name>/`
|
| 361 |
+
- If you're trying to run the example ResNet-based model, make sure you ran `make download` to fetch the model weights.
|
| 362 |
+
- Check logs for the exact error
|
| 363 |
+
|
| 364 |
+
**Slow inference:**
|
| 365 |
+
- Inference runs on CPU by default
|
| 366 |
+
- For GPU: install CUDA PyTorch and modify service to use GPU device
|
| 367 |
+
- Consider using smaller models or quantization
|
| 368 |
+
|
| 369 |
+
## License
|
| 370 |
+
|
| 371 |
+
Apache 2.0
|
app/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ML Inference Service
|
| 3 |
+
|
| 4 |
+
A FastAPI-based web service for machine learning model inference.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
__version__ = "0.1.0"
|
app/api/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""API layer for the ML inference service."""
|
app/api/controllers.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""API controllers for request handling and validation."""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
from fastapi import HTTPException
|
| 5 |
+
|
| 6 |
+
from app.core.logging import logger
|
| 7 |
+
from app.services.base import InferenceService
|
| 8 |
+
from app.api.models import ImageRequest, PredictionResponse
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class PredictionController:
    """Controller for prediction endpoints."""

    @staticmethod
    async def predict(
        request: ImageRequest,
        service: InferenceService
    ) -> PredictionResponse:
        """Run inference using the configured service.

        Checks that a loaded service is available and that the payload
        declares an image media type, then executes the (blocking) model
        call in a worker thread so the event loop stays responsive.

        Raises:
            HTTPException: 503 when no service is loaded, 400 on invalid
                input, 500 on any unexpected failure.
        """
        try:
            if not service or not service.is_loaded:
                raise HTTPException(503, "Service not available")

            media_type = request.image.mediaType
            if not media_type.startswith('image/'):
                raise HTTPException(400, f"Invalid media type: {media_type}")

            # service.predict is synchronous/CPU-bound; run it off-loop.
            return await asyncio.to_thread(service.predict, request)
        except HTTPException:
            # Already a well-formed API error — propagate unchanged.
            raise
        except ValueError as err:
            logger.error(f"Invalid input: {err}")
            raise HTTPException(400, str(err))
        except Exception as err:
            logger.error(f"Prediction failed: {err}")
            raise HTTPException(500, "Internal server error")
|
app/api/models.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pydantic models for request/response validation.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import base64
|
| 8 |
+
import enum
|
| 9 |
+
import io
|
| 10 |
+
import typing
|
| 11 |
+
from typing import Literal, Optional
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import pydantic
|
| 15 |
+
from PIL import Image
|
| 16 |
+
|
| 17 |
+
if typing.TYPE_CHECKING:
|
| 18 |
+
from numpy.typing import NDArray
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ImageData(pydantic.BaseModel):
    """Image data model for base64 encoded images."""

    # IETF media type of the payload, e.g. "image/jpeg" or "image/png".
    mediaType: str = pydantic.Field(
        description="The IETF Media Type (MIME type) of the data"
    )
    # Base64 payload; the pattern enforces *canonical* encoding (correct
    # padding characters), rejecting otherwise-decodable sloppy input.
    data: str = pydantic.Field(
        description="A base64 string encoding of the data.",
        # Canonical base64 encoding
        # https://stackoverflow.com/a/64467300/3709935
        pattern=r"^(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/][AQgw]==|[A-Za-z0-9+/]{2}[AEIMQUYcgkosw048]=)?$",
    )
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class BinaryMask(pydantic.BaseModel):
    """A bit mask indicating which pixels are manipulated / synthesized. A
    pixel value of ``0`` means "no detection", and a value of ``1`` means
    "detection".

    The mask data must be encoded in PNG format, so that typical mask data is
    compressed effectively. The PNG encoding **should** use "bilevel" mode for
    maximum compactness. You can use the ``BinaryMask.from_numpy()``
    function to convert a 0-1 numpy array to a BinaryMask.
    """

    # Fixed to PNG — the only accepted mask encoding.
    mediaType: Literal["image/png"] = pydantic.Field(
        description="The IETF Media Type (MIME type) of the data."
    )
    # Canonical base64 of the PNG bytes (same pattern as ImageData.data).
    data: str = pydantic.Field(
        description="A base64 string encoding of the data.",
        # Canonical base64 encoding
        # https://stackoverflow.com/a/64467300/3709935
        pattern=r"^(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/][AQgw]==|[A-Za-z0-9+/]{2}[AEIMQUYcgkosw048]=)?$",
    )

    @staticmethod
    def from_numpy(mask: NDArray[np.uint8]) -> BinaryMask:
        """Convert a 0-1 numpy array to a BinaryMask.

        The numpy data must be in row-major order. That means the first
        dimension corresponds to **height** and the second dimension corresponds
        to **width**.
        """
        # Convert to "L" (grayscale) then "1" (bilevel) for compact binary representation.
        # NOTE(review): assumes values are exactly 0/1 in uint8 — `mask * 255`
        # wraps for any other values; confirm against callers.
        mask_img = Image.fromarray(mask * 255, mode="L").convert("1", dither=None)
        mask_img_buffer = io.BytesIO()
        mask_img.save(mask_img_buffer, format="png")
        mask_data = base64.b64encode(mask_img_buffer.getbuffer()).decode("utf-8")
        return BinaryMask(mediaType="image/png", data=mask_data)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class ImageRequest(pydantic.BaseModel):
    """Request model for image classification."""

    # The image to classify: base64 payload plus its media type.
    image: ImageData
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class Labels(enum.IntEnum):
    """Integer categories for the 4-way classification task.

    See ``PredictionResponse`` for the full description of each label;
    the four values correspond to the four entries of ``logprobs``.
    """

    Natural = 0              # unaltered natural image
    FullySynthesized = 1     # entire image generated
    LocallyEdited = 2        # region edited by traditional photo-editing
    LocallySynthesized = 3   # region replaced by synthesized content
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class PredictionResponse(pydantic.BaseModel):
    """Response model for synthetic image classification results.

    Detector models will be scored primarily on their ability to classify the
    entire image into 1 of the 4 label categories::

        0: (Natural) The image is natural / unaltered.
        1: (FullySynthesized) The entire image was synthesized by e.g., a
           generative image model.
        2: (LocallyEdited) The image is a natural image where a portion has
           been edited using traditional photo editing techniques such as
           splicing.
        3: (LocallySynthesized) The image is a natural image where a portion
           has been replaced by synthesized content.
    """

    # Exactly 4 entries, one per label category above.
    logprobs: list[float] = pydantic.Field(
        description="The log-probabilities for each of the 4 possible labels.",
        min_length=4,
        max_length=4,
    )

    localizationMask: Optional[BinaryMask] = pydantic.Field(
        # Fix: the contract (docstring below and README) says this field may
        # be omitted, but without an explicit default pydantic v2 treats a
        # default-less Field as REQUIRED even when the annotation is
        # Optional. `default=None` makes omission actually valid.
        default=None,
        description="A bit mask localizing predicted edits. Models that are"
        " not capable of localization may omit this field. It may also be"
        " omitted if the predicted label is ``0`` or ``1``, in which case the"
        " mask will be assumed to be all 0's or all 1's, as appropriate."
    )
|
app/api/routes/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""API route definitions."""
|
app/api/routes/prediction.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Prediction API routes."""
|
| 2 |
+
|
| 3 |
+
from fastapi import APIRouter, Depends
|
| 4 |
+
|
| 5 |
+
from app.api.controllers import PredictionController
|
| 6 |
+
from app.api.models import ImageRequest, PredictionResponse
|
| 7 |
+
from app.core.dependencies import get_inference_service
|
| 8 |
+
from app.services.base import InferenceService
|
| 9 |
+
|
| 10 |
+
router = APIRouter()


@router.post("/predict", response_model=PredictionResponse)
async def predict(
    request: ImageRequest,
    service: InferenceService = Depends(get_inference_service)
):
    """Run inference on a base64-encoded image.

    Returns a ``PredictionResponse``: per-label log-probabilities plus an
    optional localization mask (see ``app.api.models.PredictionResponse``).
    Validation and error mapping are delegated to ``PredictionController``.
    """
    return await PredictionController.predict(request, service)
|
app/core/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Core utilities and configurations."""
|
app/core/app.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI application factory and core infrastructure."""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import warnings
|
| 5 |
+
from contextlib import asynccontextmanager
|
| 6 |
+
from typing import AsyncGenerator
|
| 7 |
+
|
| 8 |
+
from fastapi import FastAPI
|
| 9 |
+
from pydantic import Field
|
| 10 |
+
from pydantic_settings import BaseSettings
|
| 11 |
+
|
| 12 |
+
from app.core.logging import logger
|
| 13 |
+
from app.core.dependencies import set_inference_service
|
| 14 |
+
from app.services.inference import ResNetInferenceService
|
| 15 |
+
from app.services.inference_clip import CLIPInferenceService
|
| 16 |
+
|
| 17 |
+
from app.services.inference_dino import DINOInferenceService
|
| 18 |
+
from app.api.routes import prediction
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Settings(BaseSettings):
    """Application settings. Override via environment variables or .env file."""

    app_name: str = Field(default="ML Inference Service")
    app_version: str = Field(default="0.1.0")
    debug: bool = Field(default=False)
    host: str = Field(default="0.0.0.0")
    port: int = Field(default=8000)

    class Config:
        # Values may also be supplied by a local .env file (see .env.example).
        env_file = ".env"


# Module-level singleton, resolved once at import time.
settings = Settings()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Application lifecycle: load the detector at startup, log at shutdown."""
    logger.info("Starting ML Inference Service...")
    try:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning)

            # Swap in your own InferenceService implementation here:
            # service = ResNetInferenceService(model_name="microsoft/resnet-18")
            # service = CLIPInferenceService(model_name="clip")
            detector = DINOInferenceService(model_name="dino")
            # Model loading is blocking; run it off the event loop.
            await asyncio.to_thread(detector.load_model)
            set_inference_service(detector)

        logger.info("Startup completed successfully")
    except Exception as exc:
        logger.error(f"Startup failed: {exc}")
        raise

    yield

    logger.info("Shutting down...")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def create_app() -> FastAPI:
    """Build and configure the FastAPI application instance."""
    application = FastAPI(
        title=settings.app_name,
        description="ML inference service for image classification",
        version=settings.app_version,
        debug=settings.debug,
        lifespan=lifespan,
    )
    # Mount the /predict endpoint.
    application.include_router(prediction.router)
    return application
|
app/core/dependencies.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Dependency injection for services."""
|
| 2 |
+
|
| 3 |
+
from typing import Optional
|
| 4 |
+
from app.services.base import InferenceService
|
| 5 |
+
|
| 6 |
+
# Module-level singleton holding the currently registered inference service.
_inference_service: Optional[InferenceService] = None


def get_inference_service() -> Optional[InferenceService]:
    """Return the registered inference service (None before startup)."""
    return _inference_service


def set_inference_service(service: InferenceService) -> None:
    """Register the inference service; invoked once during application startup."""
    global _inference_service
    _inference_service = service
|
app/core/logging.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Logging configuration."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def setup_logging(logger_name: str = "ML Inference Service") -> logging.Logger:
    """Create (or fetch) a stdout logger configured at INFO level.

    Idempotent: a logger that already has handlers is returned untouched,
    so repeated calls never attach duplicate handlers.
    """
    log = logging.getLogger(logger_name)

    if not log.handlers:
        log.setLevel(logging.INFO)
        stream = logging.StreamHandler(sys.stdout)
        stream.setLevel(logging.INFO)
        stream.setFormatter(
            logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        )
        log.addHandler(stream)

    return log


# Shared module-level logger used across the application.
logger = setup_logging()
|
app/services/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Business logic services."""
|
app/services/base.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Abstract base class for ML inference services."""
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from typing import Generic, TypeVar
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
|
| 7 |
+
TRequest = TypeVar('TRequest', bound=BaseModel)
|
| 8 |
+
TResponse = TypeVar('TResponse', bound=BaseModel)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class InferenceService(ABC, Generic[TRequest, TResponse]):
    """Abstract contract for model inference backends.

    Subclass and implement the three members below to plug a model into
    the application.
    """

    @abstractmethod
    def load_model(self) -> None:
        """Load model weights and processors. Invoked once at startup."""
        ...

    @abstractmethod
    def predict(self, request: TRequest) -> TResponse:
        """Run inference on a single request and return a typed response."""
        ...

    @property
    @abstractmethod
    def is_loaded(self) -> bool:
        """Whether the model has been loaded and is ready to serve."""
        ...
|
app/services/inference.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ResNet inference service implementation."""
|
| 2 |
+
|
| 3 |
+
import base64
|
| 4 |
+
import os
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from transformers import AutoImageProcessor, ResNetForImageClassification # type: ignore[import-untyped]
|
| 11 |
+
|
| 12 |
+
from app.core.logging import logger
|
| 13 |
+
from app.services.base import InferenceService
|
| 14 |
+
from app.api.models import BinaryMask, ImageRequest, Labels, PredictionResponse
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
    """ResNet-18 inference service for image classification.

    Loads a locally stored HuggingFace ResNet checkpoint and maps its
    ImageNet logits onto the task's label space (demo-quality only).
    """

    def __init__(self, model_name: str = "microsoft/resnet-18"):
        self.model_name = model_name
        self.model = None       # populated by load_model()
        self.processor = None   # populated by load_model()
        self._is_loaded = False
        # Checkpoints are expected under ./models/<model_name>/
        self.model_path = os.path.join("models", model_name)
        logger.info(f"Initializing ResNet service: {self.model_path}")

    def load_model(self) -> None:
        """Load processor and weights from the local checkpoint. Idempotent.

        Raises:
            FileNotFoundError: if the checkpoint directory or its
                config.json is missing.
        """
        if self._is_loaded:
            return

        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model not found: {self.model_path}")

        config_path = os.path.join(self.model_path, "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config not found: {config_path}")

        logger.info(f"Loading model from {self.model_path}")

        import warnings
        with warnings.catch_warnings():
            # Silence HF deprecation chatter during model construction.
            warnings.filterwarnings("ignore", category=FutureWarning)
            self.processor = AutoImageProcessor.from_pretrained(
                self.model_path, local_files_only=True
            )
            self.model = ResNetForImageClassification.from_pretrained(
                self.model_path, local_files_only=True
            )
        assert self.model is not None

        self._is_loaded = True
        logger.info(f"Model loaded: {len(self.model.config.id2label)} classes")  # pyright: ignore

    def predict(self, request: ImageRequest) -> PredictionResponse:
        """Classify a base64-encoded image and return log-probs plus a dummy mask.

        Raises:
            RuntimeError: if called before load_model().
        """
        if not self.is_loaded:
            raise RuntimeError("model is not loaded")
        assert self.processor is not None
        assert self.model is not None

        image_data = base64.b64decode(request.image.data)
        image = Image.open(BytesIO(image_data))

        if image.mode != 'RGB':
            image = image.convert('RGB')

        inputs = self.processor(image, return_tensors="pt")

        with torch.no_grad():
            logits = self.model(**inputs).logits.squeeze()  # pyright: ignore

        # Convert the ImageNet output vector of dimension 1000 to the expected
        # output format. This is for demonstration purposes
        # and obviously will not perform well on the actual task.
        # dim=-1 is explicit: relying on log_softmax's implicit dim is deprecated
        # (identical result for this 1-D tensor).
        logprobs = torch.nn.functional.log_softmax(logits[:len(Labels)], dim=-1).tolist()

        # Dummy localization mask: a rectangle approximately in the middle.
        x = image.width // 3
        y = image.height // 3
        # Row-major order
        mask = np.zeros((image.height, image.width), dtype=np.uint8)
        mask[y:(2*y), x:(2*x)] = 1
        mask_obj = BinaryMask.from_numpy(mask)

        return PredictionResponse(
            logprobs=logprobs,
            localizationMask=mask_obj,
        )

    @property
    def is_loaded(self) -> bool:
        """True once load_model() has completed."""
        return self._is_loaded
|
app/services/inference_clip.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ResNet inference service implementation."""
|
| 2 |
+
|
| 3 |
+
import base64
|
| 4 |
+
import os
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torchvision.transforms as transforms
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from transformers import AutoImageProcessor, ResNetForImageClassification # type: ignore[import-untyped]
|
| 12 |
+
|
| 13 |
+
from app.core.logging import logger
|
| 14 |
+
from app.services.base import InferenceService
|
| 15 |
+
from app.api.models import BinaryMask, ImageRequest, Labels, PredictionResponse
|
| 16 |
+
|
| 17 |
+
from app.services.models.alignment_pretrained.unet import UNetImageDecoder
|
| 18 |
+
from app.services.models.CLIPSvD import CLIPSvD
|
| 19 |
+
from app.services.models.alignment_pretrained.model_with_bce_images_blip import MMModerator
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class CLIPDINOProcessor:
    """Image processor for CLIP and DINO models: square resize plus ImageNet normalization."""

    def __init__(self, image_size: int = 224):
        self.image_size = image_size
        # Standard ImageNet statistics, shared by both CLIP and DINO.
        self.transform = transforms.Compose([
            transforms.Resize(
                (image_size, image_size),
                interpolation=transforms.InterpolationMode.BICUBIC,
            ),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def __call__(self, image: Image.Image, return_tensors: str = "pt"):
        """Transform a PIL image into a batched pixel tensor.

        Args:
            image: PIL Image object.
            return_tensors: tensor format (default "pt" for PyTorch).

        Returns:
            Dict with a 'pixel_values' key holding the processed tensor.

        Raises:
            ValueError: if `image` is not a PIL Image.
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL Image")

        pixel_values = self.transform(image)
        # Guarantee a leading batch dimension.
        if pixel_values.dim() == 3:
            pixel_values = pixel_values.unsqueeze(0)

        return {"pixel_values": pixel_values}
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def create_vision_encoder():
    """Instantiate the CLIP-based vision encoder used by the moderator model."""
    # The previously hard-coded NAS repo path was unused here and has been removed.
    return CLIPSvD()
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class CLIPInferenceService(InferenceService[ImageRequest, PredictionResponse]):
    """CLIP-based multimodal-moderator inference service for image classification."""

    def __init__(self, model_name: str = "microsoft/resnet-18"):
        self.model_name = model_name
        self.processor = CLIPDINOProcessor(image_size=224)
        self._is_loaded = False
        # Checkpoints are expected under ./models/<model_name>/
        self.model_path = os.path.join("models", model_name)

        pretraining = False
        num_classes = 4
        self.vision_encoder = create_vision_encoder()
        # U-Net decoder wired to the encoder's 16x16 patch grid (256 tokens).
        self.unet_decoder = UNetImageDecoder(
            num_patches=256,    # 16 x 16 grid
            token_dim=1024,     # encoder embedding dim
            out_channels=3,     # RGB reconstruction
            base_channels=256,
            img_size=256,
            grid_hw=(16, 16),   # must match num_patches
        )
        self.model = MMModerator(
            pretraining=pretraining,
            vision_encoder=self.vision_encoder,
            unet_decoder=self.unet_decoder,
            num_classes=num_classes,
        )
        logger.info(f"Initializing CLIP service: {self.model_path}")

    @staticmethod
    def _cleaned_state_dict(raw: dict) -> dict:
        """Unwrap an optional 'model_state_dict' entry and strip DataParallel 'module.' prefixes."""
        sd = raw.get("model_state_dict", raw)
        return {k.replace("module.", ""): v for k, v in sd.items()}

    def load_model(self) -> None:
        """Load moderator, encoder, and decoder checkpoints from disk. Idempotent.

        Raises:
            FileNotFoundError: if the checkpoint directory is missing.
        """
        if self._is_loaded:
            return

        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model not found: {self.model_path}")

        logger.info(f"Loading model from {self.model_path}")
        checkpoint_path = os.path.join(self.model_path, "model_state.pt")
        checkpoint_path_encoder = os.path.join(self.model_path, "model_state_encoder.pt")
        checkpoint_path_decoder = os.path.join(self.model_path, "model_state_decoder.pt")

        # map_location="cpu" keeps loading robust on hosts without the GPU
        # the checkpoints were saved from.
        raw = torch.load(checkpoint_path, map_location="cpu")
        raw_encoder = torch.load(checkpoint_path_encoder, map_location="cpu")
        raw_decoder = torch.load(checkpoint_path_decoder, map_location="cpu")

        self.model.load_state_dict(self._cleaned_state_dict(raw))
        self.vision_encoder.load_state_dict(self._cleaned_state_dict(raw_encoder))
        self.unet_decoder.load_state_dict(self._cleaned_state_dict(raw_decoder))  # strict=True by default

        self._is_loaded = True

    def predict(self, request: ImageRequest) -> PredictionResponse:
        """Classify a base64-encoded image; returns log-probs, no localization mask.

        Raises:
            RuntimeError: if called before load_model().
        """
        if not self.is_loaded:
            raise RuntimeError("model is not loaded")
        assert self.processor is not None
        assert self.model is not None

        image_data = base64.b64decode(request.image.data)
        image = Image.open(BytesIO(image_data))
        if image.mode != 'RGB':
            image = image.convert('RGB')

        processed = self.processor(image, return_tensors="pt")
        pixel_values = processed["pixel_values"]

        self.model.eval()
        with torch.no_grad():
            (logits_cls, logits, losses, labels_expanded,
             data_labels_expanded, image_recon, overlay) = self.model(images=pixel_values)

        # dim=-1 made explicit; log_softmax's implicit dim is deprecated.
        logprobs = torch.nn.functional.log_softmax(logits[:len(Labels)], dim=-1).tolist()[0]

        logger.debug(f"logprobs: {logprobs}")  # was a bare print()
        return PredictionResponse(
            logprobs=logprobs,
            localizationMask=None,
        )

    @property
    def is_loaded(self) -> bool:
        """True once load_model() has completed."""
        return self._is_loaded
|
app/services/inference_dino.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ResNet inference service implementation."""
|
| 2 |
+
|
| 3 |
+
import base64
|
| 4 |
+
import os
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torchvision.transforms as transforms
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from transformers import AutoImageProcessor, ResNetForImageClassification # type: ignore[import-untyped]
|
| 12 |
+
|
| 13 |
+
from app.core.logging import logger
|
| 14 |
+
from app.services.base import InferenceService
|
| 15 |
+
from app.api.models import BinaryMask, ImageRequest, Labels, PredictionResponse
|
| 16 |
+
|
| 17 |
+
from app.services.models.alignment_pretrained.unet import UNetImageDecoder
|
| 18 |
+
|
| 19 |
+
from app.services.models.DinoLORA import DINOEncoderLoRA
|
| 20 |
+
from app.services.models.alignment_pretrained.model_with_bce_images_dino import MMModerator
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class DINODINOProcessor:
    """Image processor for DINO models: square resize plus ImageNet normalization."""

    def __init__(self, image_size: int = 512):
        self.image_size = image_size
        # Standard ImageNet statistics.
        self.transform = transforms.Compose([
            transforms.Resize(
                (image_size, image_size),
                interpolation=transforms.InterpolationMode.BICUBIC,
            ),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def __call__(self, image: Image.Image, return_tensors: str = "pt"):
        """Transform a PIL image into a batched pixel tensor.

        Args:
            image: PIL Image object.
            return_tensors: tensor format (default "pt" for PyTorch).

        Returns:
            Dict with a 'pixel_values' key holding the processed tensor.

        Raises:
            ValueError: if `image` is not a PIL Image.
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL Image")

        pixel_values = self.transform(image)
        # Guarantee a leading batch dimension.
        if pixel_values.dim() == 3:
            pixel_values = pixel_values.unsqueeze(0)

        return {"pixel_values": pixel_values}
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def create_vision_encoder():
    """Build the DINOv3 ViT-B/16 backbone wrapped with LoRA adapters (r=16)."""
    # Local checkout of the dinov3 repo used as the torch.hub source.
    # NOTE(review): machine-specific NAS path — confirm it exists on the deployment host.
    REPO_DIR = "/media/NAS/USERS/shahid/MultimodalAudioVisualModerator/dinov3"


    # NOTE(review): the weights URL below is a pre-signed Meta CDN link with an
    # embedded expiry (DateLessThan policy); it will stop working once the
    # signature expires. Consider downloading the checkpoint once and pointing
    # `weights` at a local file instead — verify before deployment.
    encoder = torch.hub.load(REPO_DIR, 'dinov3_vitb16', source='local', weights="https://dinov3.llamameta.net/dinov3_vitb16/dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth?Policy=eyJTdGF0ZW1lbnQiOlt7InVuaXF1ZV9oYXNoIjoib3VpbXR2cHlhZXE5c2JwajNucnN3aWF2IiwiUmVzb3VyY2UiOiJodHRwczpcL1wvZGlub3YzLmxsYW1hbWV0YS5uZXRcLyoiLCJDb25kaXRpb24iOnsiRGF0ZUxlc3NUaGFuIjp7IkFXUzpFcG9jaFRpbWUiOjE3NTk5MDI0NDF9fX1dfQ__&Signature=GIYJR4%7ESJVx0gkcm7lgAvDljIfpR30MXgWb2VpCqbDeVpnwjn97k%7EOcPeGF-lkR0q1Sn3Iw5Y3iYWqspcpPoDJ4FXUmMKhWtd-m00HO73Aknq2kyrKVMBpzwQB-k-2zZe7okJfXTj46EWbzu9mNcxt%7ErDPe7phQpRJi8Dleida1BJ823oXFx8d7oRSa4NDSzT2TNXqNNZ8ux7N0aDfdT9dupEeEr4AP06LhYB2I7kF%7Ef4bvKQsKnlPMVDAADyYG9nQ7HqAW41LaWZtR-BrDGm%7ESNu-6L44cUVnk3qEPVRMQB4GW7ixRGGhtr37F6HVz%7EKilrCpivFD6ej4reNUWaGQ__&Key-Pair-Id=K15QRJLYKIFSLZ&Download-Request-ID=829796286371500")
    model = DINOEncoderLoRA(encoder, r=16, emb_dim=1024, use_lora=True)
    return model
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class DINOInferenceService(InferenceService[ImageRequest, PredictionResponse]):
    """DINOv3-based multimodal-moderator inference service for image classification."""

    def __init__(self, model_name: str = "microsoft/resnet-18"):
        self.model_name = model_name
        self.processor = DINODINOProcessor(image_size=512)
        self._is_loaded = False
        # Checkpoints are expected under ./models/<model_name>/
        self.model_path = os.path.join("models", model_name)

        pretraining = False
        num_classes = 4
        self.vision_encoder = create_vision_encoder()
        # U-Net decoder wired to the encoder's 32x32 patch grid (1024 tokens of dim 768).
        self.unet_decoder = UNetImageDecoder(
            num_patches=1024,    # MUST match N=1024
            token_dim=768,       # tokens are [B, 1024, 768]
            out_channels=1,      # single-channel mask output
            base_channels=128,   # recommended for 512px
            img_size=512,        # image resolution
            grid_hw=(32, 32),    # MUST match N=1024
        )
        self.model = MMModerator(
            pretraining=pretraining,
            vision_encoder=self.vision_encoder,
            unet_decoder=self.unet_decoder,
            num_classes=num_classes,
        )
        self.model.eval()
        logger.info(f"Initializing DINO service: {self.model_path}")

    @staticmethod
    def _cleaned_state_dict(raw: dict) -> dict:
        """Unwrap an optional 'model_state_dict' entry and strip DataParallel 'module.' prefixes."""
        sd = raw.get("model_state_dict", raw)
        return {k.replace("module.", ""): v for k, v in sd.items()}

    def load_model(self) -> None:
        """Load moderator, encoder, and decoder checkpoints from disk. Idempotent.

        Raises:
            FileNotFoundError: if the checkpoint directory is missing.
        """
        if self._is_loaded:
            return

        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model not found: {self.model_path}")

        logger.info(f"Loading model from {self.model_path}")
        checkpoint_path = os.path.join(self.model_path, "model_state.pt")
        checkpoint_path_encoder = os.path.join(self.model_path, "model_state_encoder.pt")
        checkpoint_path_decoder = os.path.join(self.model_path, "model_state_decoder.pt")

        # map_location="cpu" keeps loading robust on hosts without the GPU
        # the checkpoints were saved from.
        raw = torch.load(checkpoint_path, map_location="cpu")
        raw_encoder = torch.load(checkpoint_path_encoder, map_location="cpu")
        raw_decoder = torch.load(checkpoint_path_decoder, map_location="cpu")

        self.model.load_state_dict(self._cleaned_state_dict(raw))
        self.vision_encoder.load_state_dict(self._cleaned_state_dict(raw_encoder))
        self.unet_decoder.load_state_dict(self._cleaned_state_dict(raw_decoder))  # strict=True by default

        self._is_loaded = True

    def predict(self, request: ImageRequest) -> PredictionResponse:
        """Classify a base64-encoded image; returns log-probs, no localization mask.

        Raises:
            RuntimeError: if called before load_model().
        """
        if not self.is_loaded:
            raise RuntimeError("model is not loaded")
        assert self.processor is not None
        assert self.model is not None

        image_data = base64.b64decode(request.image.data)
        image = Image.open(BytesIO(image_data))
        if image.mode != 'RGB':
            image = image.convert('RGB')

        processed = self.processor(image, return_tensors="pt")
        pixel_values = processed["pixel_values"]

        self.model.eval()
        with torch.no_grad():
            (logits, logits_multi_cls, losses, label, data_label,
             image_recon, overlay, shuffled_images, gt_masks) = self.model(images=pixel_values)

        # dim=-1 made explicit; log_softmax's implicit dim is deprecated.
        logprobs = torch.nn.functional.log_softmax(logits_multi_cls[:len(Labels)], dim=-1).tolist()[0]

        logger.debug(f"logprobs: {logprobs}")  # was a bare print()
        return PredictionResponse(
            logprobs=logprobs,
            localizationMask=None,
        )

    @property
    def is_loaded(self) -> bool:
        """True once load_model() has completed."""
        return self._is_loaded
|
cat.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
challenge-cli.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: 2025 UL Research Institutes
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
import functools
|
| 5 |
+
import time
|
| 6 |
+
from datetime import datetime, timedelta, timezone
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import click
|
| 10 |
+
import httpx
|
| 11 |
+
|
| 12 |
+
from dyff.client import Client, errors
|
| 13 |
+
from dyff.schema.platform import *
|
| 14 |
+
from dyff.schema.requests import *
|
| 15 |
+
|
| 16 |
+
from app.api.models import PredictionResponse
|
| 17 |
+
|
| 18 |
+
# ----------------------------------------------------------------------------
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _wait_for_status(
|
| 22 |
+
get_entity_fn, target_status: str | list[str], *, timeout: timedelta
|
| 23 |
+
) -> str:
|
| 24 |
+
if isinstance(target_status, str):
|
| 25 |
+
target_status = [target_status]
|
| 26 |
+
then = datetime.now(timezone.utc)
|
| 27 |
+
while True:
|
| 28 |
+
try:
|
| 29 |
+
status = get_entity_fn().status
|
| 30 |
+
if status in target_status:
|
| 31 |
+
return status
|
| 32 |
+
except errors.HTTPError as ex:
|
| 33 |
+
if ex.status != 404:
|
| 34 |
+
raise
|
| 35 |
+
except httpx.HTTPStatusError as ex:
|
| 36 |
+
if ex.response.status_code != 404:
|
| 37 |
+
raise
|
| 38 |
+
if (datetime.now(timezone.utc) - then) >= timeout:
|
| 39 |
+
break
|
| 40 |
+
time.sleep(5)
|
| 41 |
+
raise AssertionError("timeout")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _common_options(f):
    """Attach the CLI options shared by every subcommand to `f`."""
    account_option = click.option(
        "--account",
        type=str,
        required=True,
        help="Your account ID",
        metavar="ID",
    )

    @functools.wraps(f)
    def _delegate(*args, **kwargs):
        return f(*args, **kwargs)

    return account_option(_delegate)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# Root command group; subcommands are registered below via @cli.command().
@click.group()
def cli():
    pass
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@cli.command()
|
| 64 |
+
@_common_options
|
| 65 |
+
@click.option(
|
| 66 |
+
"--name",
|
| 67 |
+
type=str,
|
| 68 |
+
required=True,
|
| 69 |
+
help="The name of your detector model. For display and querying purposes only.",
|
| 70 |
+
)
|
| 71 |
+
@click.option(
|
| 72 |
+
"--image",
|
| 73 |
+
type=str,
|
| 74 |
+
default=None,
|
| 75 |
+
help="The Docker image to upload (e.g., 'some/image:latest')."
|
| 76 |
+
" Must exist in your local Docker deamon."
|
| 77 |
+
" Required if --artifact is not specified.",
|
| 78 |
+
)
|
| 79 |
+
@click.option(
|
| 80 |
+
"--endpoint",
|
| 81 |
+
type=str,
|
| 82 |
+
default="predict",
|
| 83 |
+
help="The endpoint to call on your service to make a prediction.",
|
| 84 |
+
)
|
| 85 |
+
@click.option(
|
| 86 |
+
"--volume",
|
| 87 |
+
type=click.Path(exists=True, file_okay=False, readable=True, resolve_path=True, path_type=Path),
|
| 88 |
+
default=None,
|
| 89 |
+
help="A local directory path containing files to upload and mount in the running Docker container."
|
| 90 |
+
" You should use this if your submission includes large files like neural network weights."
|
| 91 |
+
)
|
| 92 |
+
@click.option(
|
| 93 |
+
"--volume-mount",
|
| 94 |
+
type=click.Path(exists=False, path_type=Path),
|
| 95 |
+
default=None,
|
| 96 |
+
help="The path to mount your uploaded directory in the running Docker container."
|
| 97 |
+
" Must be an absolute path."
|
| 98 |
+
" Required if --volume is specified.")
|
| 99 |
+
@click.option(
|
| 100 |
+
"--artifact",
|
| 101 |
+
"artifact_id",
|
| 102 |
+
type=str,
|
| 103 |
+
default=None,
|
| 104 |
+
help="The ID of the Artifact (i.e., Docker image) to use in the submission, if it already exists."
|
| 105 |
+
" You can pass the artifact.id from a previous invocation.",
|
| 106 |
+
metavar="ID",
|
| 107 |
+
)
|
| 108 |
+
@click.option(
|
| 109 |
+
"--model",
|
| 110 |
+
"model_id",
|
| 111 |
+
type=str,
|
| 112 |
+
default=None,
|
| 113 |
+
help="The ID of the Model (i.e., neural network weights) to use in the submission, if it already exists."
|
| 114 |
+
" You can pass the model.id from a previous invocation.",
|
| 115 |
+
metavar="ID",
|
| 116 |
+
)
|
| 117 |
+
@click.option(
|
| 118 |
+
"--gpu",
|
| 119 |
+
is_flag=True,
|
| 120 |
+
default=False,
|
| 121 |
+
help="Request a GPU (NVIDIA L4) for the inference service.",
|
| 122 |
+
)
|
| 123 |
+
def upload_submission(
|
| 124 |
+
account: str,
|
| 125 |
+
name: str,
|
| 126 |
+
image: str | None,
|
| 127 |
+
endpoint: str,
|
| 128 |
+
volume: Path | None,
|
| 129 |
+
volume_mount: Path | None,
|
| 130 |
+
artifact_id: str | None,
|
| 131 |
+
model_id: str | None,
|
| 132 |
+
gpu: bool,
|
| 133 |
+
) -> None:
|
| 134 |
+
dyffapi = Client()
|
| 135 |
+
|
| 136 |
+
# Upload the image
|
| 137 |
+
if artifact_id is None:
|
| 138 |
+
# Create an Artifact resource
|
| 139 |
+
click.echo(f"creating Artifact ... {account}")
|
| 140 |
+
artifact = dyffapi.artifacts.create(ArtifactCreateRequest(account=account))
|
| 141 |
+
click.echo(f"artifact.id: \"{artifact.id}\"")
|
| 142 |
+
_wait_for_status(
|
| 143 |
+
lambda: dyffapi.artifacts.get(artifact.id),
|
| 144 |
+
"WaitingForUpload",
|
| 145 |
+
timeout=timedelta(seconds=30),
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
# Push the image from the local Docker daemon
|
| 149 |
+
click.echo("pushing Artifact ...")
|
| 150 |
+
dyffapi.artifacts.push(artifact, source=f"docker-daemon:{image}")
|
| 151 |
+
time.sleep(5)
|
| 152 |
+
|
| 153 |
+
# Indicate that we're done pushing
|
| 154 |
+
dyffapi.artifacts.finalize(artifact.id)
|
| 155 |
+
_wait_for_status(
|
| 156 |
+
lambda: dyffapi.artifacts.get(artifact.id),
|
| 157 |
+
"Ready",
|
| 158 |
+
timeout=timedelta(seconds=30),
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
click.echo("... done")
|
| 162 |
+
else:
|
| 163 |
+
artifact = dyffapi.artifacts.get(artifact_id)
|
| 164 |
+
assert artifact is not None
|
| 165 |
+
|
| 166 |
+
model: Model | None = None
|
| 167 |
+
if model_id is None:
|
| 168 |
+
if volume is not None:
|
| 169 |
+
if volume_mount is None:
|
| 170 |
+
raise click.UsageError("--volume-mount is required when --volume is used")
|
| 171 |
+
|
| 172 |
+
click.echo("creating Model from local directory ...")
|
| 173 |
+
|
| 174 |
+
model = dyffapi.models.create_from_volume(
|
| 175 |
+
volume, name="model_volume", account=account, resources=ModelResources()
|
| 176 |
+
)
|
| 177 |
+
click.echo(f"model.id: \"{model.id}\"")
|
| 178 |
+
_wait_for_status(
|
| 179 |
+
lambda: dyffapi.models.get(model.id),
|
| 180 |
+
"WaitingForUpload",
|
| 181 |
+
timeout=timedelta(seconds=30),
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
click.echo("uploading Model ...")
|
| 185 |
+
dyffapi.models.upload_volume(model, volume)
|
| 186 |
+
_wait_for_status(
|
| 187 |
+
lambda: dyffapi.models.get(model.id),
|
| 188 |
+
"Ready",
|
| 189 |
+
timeout=timedelta(seconds=30),
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
click.echo("... done")
|
| 193 |
+
else:
|
| 194 |
+
model = None
|
| 195 |
+
else:
|
| 196 |
+
model = dyffapi.models.get(model_id)
|
| 197 |
+
assert model is not None
|
| 198 |
+
|
| 199 |
+
# Create a runnable InferenceService
|
| 200 |
+
if volume_mount is not None:
|
| 201 |
+
if model is None:
|
| 202 |
+
raise click.UsageError("--volume-mount requires --volume or --model")
|
| 203 |
+
if not volume_mount.is_absolute():
|
| 204 |
+
raise click.UsageError("--volume-mount must be an absolute path")
|
| 205 |
+
volumeMounts=[
|
| 206 |
+
VolumeMount(
|
| 207 |
+
kind=VolumeMountKind.data,
|
| 208 |
+
name="model",
|
| 209 |
+
mountPath=volume_mount,
|
| 210 |
+
data=VolumeMountData(
|
| 211 |
+
source=EntityIdentifier.of(model),
|
| 212 |
+
),
|
| 213 |
+
),
|
| 214 |
+
]
|
| 215 |
+
else:
|
| 216 |
+
volumeMounts = None
|
| 217 |
+
|
| 218 |
+
accelerator: Accelerator | None = None
|
| 219 |
+
if gpu:
|
| 220 |
+
accelerator = Accelerator(
|
| 221 |
+
kind="GPU",
|
| 222 |
+
gpu=AcceleratorGPU(
|
| 223 |
+
hardwareTypes=["nvidia.com/gpu-l4"],
|
| 224 |
+
count=1,
|
| 225 |
+
),
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
# Don't change this
|
| 229 |
+
service_request = InferenceServiceCreateRequest(
|
| 230 |
+
account=account,
|
| 231 |
+
name=name,
|
| 232 |
+
model=None,
|
| 233 |
+
runner=InferenceServiceRunner(
|
| 234 |
+
kind=InferenceServiceRunnerKind.CONTAINER,
|
| 235 |
+
imageRef=EntityIdentifier.of(artifact),
|
| 236 |
+
resources=ModelResources(),
|
| 237 |
+
volumeMounts=volumeMounts,
|
| 238 |
+
accelerator=accelerator,
|
| 239 |
+
),
|
| 240 |
+
interface=InferenceInterface(
|
| 241 |
+
endpoint=endpoint,
|
| 242 |
+
outputSchema=DataSchema.make_output_schema(PredictionResponse),
|
| 243 |
+
),
|
| 244 |
+
)
|
| 245 |
+
click.echo("creating InferenceService ...")
|
| 246 |
+
service = dyffapi.inferenceservices.create(service_request)
|
| 247 |
+
click.echo(f"service.id: \"{service.id}\"")
|
| 248 |
+
click.echo("... done")
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
@cli.command()
@_common_options
@click.option(
    "--task",
    "task_id",
    type=str,
    required=True,
    help="The Task ID to submit to.",
    metavar="ID",
)
@click.option(
    "--team",
    "team_id",
    type=str,
    required=True,
    help="The Team ID making the submission.",
    metavar="ID",
)
@click.option(
    "--service",
    "service_id",
    type=str,
    required=True,
    help="The InferenceService ID to submit.",
    metavar="ID",
)
@click.option(
    "--challenge",
    "challenge_id",
    type=str,
    default="dc509a8c771b492b90c43012fde9a04f",
    help="The Challenge ID to submit to.",
    metavar="ID",
)
def submit(account: str, task_id: str, team_id: str, service_id: str, challenge_id: str) -> None:
    # Register an existing InferenceService as a team's submission to a
    # specific task of a challenge, then echo the created Submission record.
    dyffapi = Client()

    # Resolve the challenge and look up the requested task inside it.
    challenge = dyffapi.challenges.get(challenge_id)
    challengetask = challenge.tasks[task_id]

    # Resolve the submitting team and the service being submitted.
    team = dyffapi.teams.get(team_id)
    service = dyffapi.inferenceservices.get(service_id)

    # Create the submission; the payload references the service by identifier.
    submission = dyffapi.challenges.submit(
        challenge.id,
        challengetask.id,
        SubmissionCreateRequest(
            account=account,
            team=team.id,
            submission=EntityIdentifier(kind="InferenceService", id=service.id),
        ),
    )

    # Show the full record plus the bare id for easy copy/paste.
    click.echo(submission.model_dump_json(indent=2))
    click.echo(f'submission.id: "{submission.id}"')


if __name__ == "__main__":
    cli(show_default=True)
|
delete_corrupted_files.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
from PIL import Image
|
| 4 |
+
|
| 5 |
+
def is_image_corrupted(path):
    """Return True if the image at *path* cannot be opened by cv2 or PIL."""
    # OpenCV signals an undecodable file by returning None rather than raising.
    if cv2.imread(path) is None:
        return True

    # PIL's verify() validates file integrity without fully loading pixels.
    try:
        with Image.open(path) as handle:
            handle.verify()
    except Exception:
        return True  # PIL failed to open / verify

    # Both decoders accepted the file.
    return False
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def delete_corrupted_images(directory):
    """
    Delete all unreadable / corrupted image files directly inside *directory*.

    Only files with a recognized image extension are checked; subdirectories
    are not recursed into. Each deletion is printed, followed by a summary
    count of how many files were removed.
    """
    valid_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}

    deleted_count = 0

    for filename in os.listdir(directory):
        path = os.path.join(directory, filename)

        # Skip folders
        if not os.path.isfile(path):
            continue

        # Check valid image extensions
        ext = os.path.splitext(filename)[1].lower()
        if ext not in valid_extensions:
            continue

        # Check corruption
        if is_image_corrupted(path):
            # BUG FIX: the f-string previously contained no placeholder and
            # printed the literal text "(unknown)" instead of the file path.
            print(f"Deleting corrupted image: {path}")
            os.remove(path)
            deleted_count += 1

    print(f"\nDone. Deleted {deleted_count} corrupted images.")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# --------- RUN ----------
if __name__ == "__main__":
    # Guard the destructive scan behind a __main__ check so that importing
    # this module never deletes files as a side effect. Running the script
    # directly behaves exactly as before.
    directory = "/media/NAS/DATASET/wacv_safe_challenge_2025_dataset/physical_fakes/PS-Battles/originals/"  # <-- change this
    delete_corrupted_images(directory)
|
main.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Main FastAPI application entry point.
|
| 3 |
+
"""
|
| 4 |
+
from app.core.app import create_app
|
| 5 |
+
|
| 6 |
+
app = create_app()
|
makefile
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Build/serve helpers for the example-submission service.
# Requires `uv` for dependency compilation and virtualenv management.

VENV ?= venv
PYTHON ?= $(VENV)/bin/python3
UVICORN ?= $(VENV)/bin/uvicorn
DOCKER ?= docker
IMAGE ?= safe-challenge-2025/example-submission
# Host uid/gid, available for container runs that need matching ownership.
GID ?= $(shell id -g)
UID ?= $(shell id -u)

# Create the venv and install pinned CPU requirements into it.
.PHONY: setup
setup: $(VENV)/requirements.cpu.txt

# Fetch model weights via the helper script.
.PHONY: download
download:
	bash scripts/model_download.bash

# Run the FastAPI app locally with the venv's uvicorn.
.PHONY: serve
serve:
	$(UVICORN) main:app

.PHONY: docker-build
docker-build:
	docker build -t $(IMAGE) .

.PHONY: docker-run
docker-run:
	docker run --rm -it -p 8000:8000 $(IMAGE)

# Re-pin CPU requirements: compile the torch pins against the CPU wheel
# index (keeping only torch* lines), then compile the general requirements.
.PHONY: compile
compile:
	uv pip compile --python-version 3.12 --upgrade -o requirements.torch.cpu.txt.tmp requirements.torch.cpu.in
	echo "--index-url https://download.pytorch.org/whl/cpu" > requirements.torch.cpu.txt
	grep -e '^torch' requirements.torch.cpu.txt.tmp >> requirements.torch.cpu.txt
	uv pip compile --python-version 3.12 --upgrade -o requirements.cpu.txt requirements.cpu.in

.PHONY: compile-gpu
compile-gpu:
	uv pip compile --python-version 3.12 --upgrade -o requirements.torch.gpu.txt requirements.torch.gpu.in

.PHONY: docker-build-gpu
docker-build-gpu:
	docker build -t $(IMAGE)-gpu -f Dockerfile.gpu .

# File rules: rebuild lock files when their .in inputs change.
# `| $(VENV)` is an order-only prerequisite: the venv must exist but its
# timestamp does not trigger rebuilds.
requirements.cpu.txt: requirements.in requirements.torch.cpu.txt | $(VENV)
	uv pip compile --python-version 3.12 --upgrade -o requirements.cpu.txt requirements.cpu.in

requirements.torch.cpu.txt: requirements.torch.cpu.in | $(VENV)
	uv pip compile --python-version 3.12 --upgrade -o requirements.torch.cpu.txt.tmp requirements.torch.cpu.in
	echo "--index-url https://download.pytorch.org/whl/cpu" > requirements.torch.cpu.txt
	cat requirements.torch.cpu.txt.tmp | grep '^torch' >> requirements.torch.cpu.txt

# Install both lock files into the venv, then copy them in as a stamp so
# `setup` is a no-op until the pins change.
$(VENV)/requirements.cpu.txt: requirements.cpu.txt | $(VENV)
	VIRTUAL_ENV=$(VENV) uv pip install -r requirements.cpu.txt
	VIRTUAL_ENV=$(VENV) uv pip install -r requirements.torch.cpu.txt
	cp -f requirements.cpu.txt $(VENV)/requirements.cpu.txt
	cp -f requirements.torch.cpu.txt $(VENV)/requirements.torch.cpu.txt

$(VENV):
	uv venv $(VENV)
|
mask.png
ADDED
|
prompt.sh
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# POST a JSON payload file to the local prediction endpoint.
# Usage: ./prompt.sh payload.json

# Fail early with a usage message instead of sending an empty body
# when no payload file is given.
if [ -z "$1" ]; then
    echo "usage: $0 <payload.json>" >&2
    exit 1
fi

curl -X POST http://localhost:8000/predict \
    -H "Content-Type: application/json" \
    -d @"$1"
|
| 6 |
+
|
requirements copy.txt
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
numpy==1.26.1
|
| 3 |
+
pandas==2.2.1
|
| 4 |
+
scipy==1.12.0
|
| 5 |
+
|
| 6 |
+
torch==2.4.0
|
| 7 |
+
torchvision==0.19.0
|
| 8 |
+
torchaudio==2.4.0
|
| 9 |
+
|
| 10 |
+
opencv-contrib-python==4.11.0.86
|
| 11 |
+
av==14.0.1
|
| 12 |
+
decord
|
| 13 |
+
kornia==0.7.4
|
| 14 |
+
timm
|
| 15 |
+
einops
|
| 16 |
+
beartype
|
| 17 |
+
lpips==0.1.4
|
| 18 |
+
imageio
|
| 19 |
+
|
| 20 |
+
audiomentations==0.36.0
|
| 21 |
+
librosa==0.10.2.post1
|
| 22 |
+
soundfile==0.12.1
|
| 23 |
+
|
| 24 |
+
transformers==4.49.0
|
| 25 |
+
datasets==3.5.1
|
| 26 |
+
tqdm
|
requirements.cpu.in
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
-r requirements.in
|
requirements.cpu.txt
ADDED
|
@@ -0,0 +1,561 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file was autogenerated by uv via the following command:
|
| 2 |
+
# uv pip compile --python-version 3.12 -o requirements.cpu.txt requirements.cpu.in
|
| 3 |
+
absl-py==2.3.1
|
| 4 |
+
# via dyff-client
|
| 5 |
+
annotated-doc==0.0.4
|
| 6 |
+
# via fastapi
|
| 7 |
+
annotated-types==0.7.0
|
| 8 |
+
# via pydantic
|
| 9 |
+
anyio==4.12.0
|
| 10 |
+
# via
|
| 11 |
+
# httpx
|
| 12 |
+
# jupyter-server
|
| 13 |
+
# starlette
|
| 14 |
+
# watchfiles
|
| 15 |
+
argon2-cffi==25.1.0
|
| 16 |
+
# via jupyter-server
|
| 17 |
+
argon2-cffi-bindings==25.1.0
|
| 18 |
+
# via argon2-cffi
|
| 19 |
+
arrow==1.4.0
|
| 20 |
+
# via isoduration
|
| 21 |
+
asttokens==3.0.1
|
| 22 |
+
# via stack-data
|
| 23 |
+
async-lru==2.0.5
|
| 24 |
+
# via jupyterlab
|
| 25 |
+
attrs==25.4.0
|
| 26 |
+
# via
|
| 27 |
+
# jsonschema
|
| 28 |
+
# referencing
|
| 29 |
+
azure-core==1.36.0
|
| 30 |
+
# via dyff-client
|
| 31 |
+
babel==2.17.0
|
| 32 |
+
# via jupyterlab-server
|
| 33 |
+
beartype==0.22.8
|
| 34 |
+
# via -r requirements.in
|
| 35 |
+
beautifulsoup4==4.14.3
|
| 36 |
+
# via
|
| 37 |
+
# dyff-audit
|
| 38 |
+
# nbconvert
|
| 39 |
+
bleach==6.3.0
|
| 40 |
+
# via nbconvert
|
| 41 |
+
canonicaljson==2.0.0
|
| 42 |
+
# via dyff-schema
|
| 43 |
+
certifi==2025.11.12
|
| 44 |
+
# via
|
| 45 |
+
# httpcore
|
| 46 |
+
# httpx
|
| 47 |
+
# requests
|
| 48 |
+
cffi==2.0.0
|
| 49 |
+
# via argon2-cffi-bindings
|
| 50 |
+
charset-normalizer==3.4.4
|
| 51 |
+
# via requests
|
| 52 |
+
click==8.3.1
|
| 53 |
+
# via
|
| 54 |
+
# -r requirements.in
|
| 55 |
+
# uvicorn
|
| 56 |
+
comm==0.2.3
|
| 57 |
+
# via ipykernel
|
| 58 |
+
contourpy==1.3.3
|
| 59 |
+
# via matplotlib
|
| 60 |
+
cycler==0.12.1
|
| 61 |
+
# via matplotlib
|
| 62 |
+
debugpy==1.8.17
|
| 63 |
+
# via ipykernel
|
| 64 |
+
decorator==5.2.1
|
| 65 |
+
# via ipython
|
| 66 |
+
defusedxml==0.7.1
|
| 67 |
+
# via nbconvert
|
| 68 |
+
dnspython==2.8.0
|
| 69 |
+
# via email-validator
|
| 70 |
+
dyff-audit==0.16.7
|
| 71 |
+
# via -r requirements.in
|
| 72 |
+
dyff-client==0.23.9
|
| 73 |
+
# via
|
| 74 |
+
# -r requirements.in
|
| 75 |
+
# dyff-audit
|
| 76 |
+
dyff-schema==0.40.1
|
| 77 |
+
# via
|
| 78 |
+
# -r requirements.in
|
| 79 |
+
# dyff-audit
|
| 80 |
+
# dyff-client
|
| 81 |
+
einops==0.8.1
|
| 82 |
+
# via -r requirements.in
|
| 83 |
+
email-validator==2.3.0
|
| 84 |
+
# via dyff-schema
|
| 85 |
+
executing==2.2.1
|
| 86 |
+
# via stack-data
|
| 87 |
+
fastapi==0.124.1
|
| 88 |
+
# via -r requirements.in
|
| 89 |
+
fastjsonschema==2.21.2
|
| 90 |
+
# via nbformat
|
| 91 |
+
filelock==3.20.0
|
| 92 |
+
# via
|
| 93 |
+
# huggingface-hub
|
| 94 |
+
# torch
|
| 95 |
+
# transformers
|
| 96 |
+
fonttools==4.61.0
|
| 97 |
+
# via matplotlib
|
| 98 |
+
fqdn==1.5.1
|
| 99 |
+
# via jsonschema
|
| 100 |
+
fsspec==2025.12.0
|
| 101 |
+
# via
|
| 102 |
+
# huggingface-hub
|
| 103 |
+
# torch
|
| 104 |
+
google-i18n-address==3.1.1
|
| 105 |
+
# via dyff-schema
|
| 106 |
+
h11==0.16.0
|
| 107 |
+
# via
|
| 108 |
+
# httpcore
|
| 109 |
+
# uvicorn
|
| 110 |
+
hf-xet==1.2.0
|
| 111 |
+
# via huggingface-hub
|
| 112 |
+
httpcore==1.0.9
|
| 113 |
+
# via httpx
|
| 114 |
+
httptools==0.7.1
|
| 115 |
+
# via uvicorn
|
| 116 |
+
httpx==0.28.1
|
| 117 |
+
# via
|
| 118 |
+
# dyff-client
|
| 119 |
+
# jupyterlab
|
| 120 |
+
huggingface-hub==0.36.0
|
| 121 |
+
# via
|
| 122 |
+
# timm
|
| 123 |
+
# tokenizers
|
| 124 |
+
# transformers
|
| 125 |
+
hypothesis==6.148.7
|
| 126 |
+
# via
|
| 127 |
+
# dyff-schema
|
| 128 |
+
# hypothesis-jsonschema
|
| 129 |
+
hypothesis-jsonschema==0.23.1
|
| 130 |
+
# via dyff-schema
|
| 131 |
+
idna==3.11
|
| 132 |
+
# via
|
| 133 |
+
# anyio
|
| 134 |
+
# email-validator
|
| 135 |
+
# httpx
|
| 136 |
+
# jsonschema
|
| 137 |
+
# requests
|
| 138 |
+
ipykernel==7.1.0
|
| 139 |
+
# via jupyterlab
|
| 140 |
+
ipython==9.8.0
|
| 141 |
+
# via ipykernel
|
| 142 |
+
ipython-pygments-lexers==1.1.1
|
| 143 |
+
# via ipython
|
| 144 |
+
isodate==0.7.2
|
| 145 |
+
# via dyff-client
|
| 146 |
+
isoduration==20.11.0
|
| 147 |
+
# via jsonschema
|
| 148 |
+
jedi==0.19.2
|
| 149 |
+
# via ipython
|
| 150 |
+
jinja2==3.1.6
|
| 151 |
+
# via
|
| 152 |
+
# jupyter-server
|
| 153 |
+
# jupyterlab
|
| 154 |
+
# jupyterlab-server
|
| 155 |
+
# nbconvert
|
| 156 |
+
# torch
|
| 157 |
+
json5==0.12.1
|
| 158 |
+
# via jupyterlab-server
|
| 159 |
+
jsonpath-ng==1.7.0
|
| 160 |
+
# via
|
| 161 |
+
# dyff-client
|
| 162 |
+
# dyff-schema
|
| 163 |
+
jsonpointer==3.0.0
|
| 164 |
+
# via jsonschema
|
| 165 |
+
jsonschema==4.25.1
|
| 166 |
+
# via
|
| 167 |
+
# hypothesis-jsonschema
|
| 168 |
+
# jupyter-events
|
| 169 |
+
# jupyterlab-server
|
| 170 |
+
# nbformat
|
| 171 |
+
jsonschema-specifications==2025.9.1
|
| 172 |
+
# via jsonschema
|
| 173 |
+
jupyter-client==8.7.0
|
| 174 |
+
# via
|
| 175 |
+
# ipykernel
|
| 176 |
+
# jupyter-server
|
| 177 |
+
# nbclient
|
| 178 |
+
jupyter-core==5.9.1
|
| 179 |
+
# via
|
| 180 |
+
# ipykernel
|
| 181 |
+
# jupyter-client
|
| 182 |
+
# jupyter-server
|
| 183 |
+
# jupyterlab
|
| 184 |
+
# nbclient
|
| 185 |
+
# nbconvert
|
| 186 |
+
# nbformat
|
| 187 |
+
jupyter-events==0.12.0
|
| 188 |
+
# via jupyter-server
|
| 189 |
+
jupyter-lsp==2.3.0
|
| 190 |
+
# via jupyterlab
|
| 191 |
+
jupyter-server==2.17.0
|
| 192 |
+
# via
|
| 193 |
+
# jupyter-lsp
|
| 194 |
+
# jupyterlab
|
| 195 |
+
# jupyterlab-server
|
| 196 |
+
# notebook
|
| 197 |
+
# notebook-shim
|
| 198 |
+
jupyter-server-terminals==0.5.3
|
| 199 |
+
# via jupyter-server
|
| 200 |
+
jupyterlab==4.5.0
|
| 201 |
+
# via notebook
|
| 202 |
+
jupyterlab-pygments==0.3.0
|
| 203 |
+
# via nbconvert
|
| 204 |
+
jupyterlab-server==2.28.0
|
| 205 |
+
# via
|
| 206 |
+
# jupyterlab
|
| 207 |
+
# notebook
|
| 208 |
+
kiwisolver==1.4.9
|
| 209 |
+
# via matplotlib
|
| 210 |
+
lark==1.3.1
|
| 211 |
+
# via rfc3987-syntax
|
| 212 |
+
lpips==0.1.4
|
| 213 |
+
# via -r requirements.in
|
| 214 |
+
lxml==6.0.2
|
| 215 |
+
# via dyff-audit
|
| 216 |
+
markupsafe==3.0.3
|
| 217 |
+
# via
|
| 218 |
+
# jinja2
|
| 219 |
+
# nbconvert
|
| 220 |
+
matplotlib==3.10.7
|
| 221 |
+
# via -r requirements.in
|
| 222 |
+
matplotlib-inline==0.2.1
|
| 223 |
+
# via
|
| 224 |
+
# ipykernel
|
| 225 |
+
# ipython
|
| 226 |
+
mistune==3.1.4
|
| 227 |
+
# via nbconvert
|
| 228 |
+
mpmath==1.3.0
|
| 229 |
+
# via sympy
|
| 230 |
+
nbclient==0.10.2
|
| 231 |
+
# via nbconvert
|
| 232 |
+
nbconvert==7.16.6
|
| 233 |
+
# via
|
| 234 |
+
# dyff-audit
|
| 235 |
+
# jupyter-server
|
| 236 |
+
nbformat==5.10.4
|
| 237 |
+
# via
|
| 238 |
+
# dyff-audit
|
| 239 |
+
# jupyter-server
|
| 240 |
+
# nbclient
|
| 241 |
+
# nbconvert
|
| 242 |
+
nest-asyncio==1.6.0
|
| 243 |
+
# via ipykernel
|
| 244 |
+
networkx==3.6.1
|
| 245 |
+
# via torch
|
| 246 |
+
notebook==7.5.0
|
| 247 |
+
# via dyff-audit
|
| 248 |
+
notebook-shim==0.2.4
|
| 249 |
+
# via
|
| 250 |
+
# jupyterlab
|
| 251 |
+
# notebook
|
| 252 |
+
numpy==1.26.4
|
| 253 |
+
# via
|
| 254 |
+
# -r requirements.in
|
| 255 |
+
# contourpy
|
| 256 |
+
# dyff-audit
|
| 257 |
+
# dyff-client
|
| 258 |
+
# dyff-schema
|
| 259 |
+
# lpips
|
| 260 |
+
# matplotlib
|
| 261 |
+
# opencv-python
|
| 262 |
+
# pandas
|
| 263 |
+
# scipy
|
| 264 |
+
# torchvision
|
| 265 |
+
# transformers
|
| 266 |
+
nvidia-cublas-cu12==12.8.4.1
|
| 267 |
+
# via
|
| 268 |
+
# nvidia-cudnn-cu12
|
| 269 |
+
# nvidia-cusolver-cu12
|
| 270 |
+
# torch
|
| 271 |
+
nvidia-cuda-cupti-cu12==12.8.90
|
| 272 |
+
# via torch
|
| 273 |
+
nvidia-cuda-nvrtc-cu12==12.8.93
|
| 274 |
+
# via torch
|
| 275 |
+
nvidia-cuda-runtime-cu12==12.8.90
|
| 276 |
+
# via torch
|
| 277 |
+
nvidia-cudnn-cu12==9.10.2.21
|
| 278 |
+
# via torch
|
| 279 |
+
nvidia-cufft-cu12==11.3.3.83
|
| 280 |
+
# via torch
|
| 281 |
+
nvidia-cufile-cu12==1.13.1.3
|
| 282 |
+
# via torch
|
| 283 |
+
nvidia-curand-cu12==10.3.9.90
|
| 284 |
+
# via torch
|
| 285 |
+
nvidia-cusolver-cu12==11.7.3.90
|
| 286 |
+
# via torch
|
| 287 |
+
nvidia-cusparse-cu12==12.5.8.93
|
| 288 |
+
# via
|
| 289 |
+
# nvidia-cusolver-cu12
|
| 290 |
+
# torch
|
| 291 |
+
nvidia-cusparselt-cu12==0.7.1
|
| 292 |
+
# via torch
|
| 293 |
+
nvidia-nccl-cu12==2.27.5
|
| 294 |
+
# via torch
|
| 295 |
+
nvidia-nvjitlink-cu12==12.8.93
|
| 296 |
+
# via
|
| 297 |
+
# nvidia-cufft-cu12
|
| 298 |
+
# nvidia-cusolver-cu12
|
| 299 |
+
# nvidia-cusparse-cu12
|
| 300 |
+
# torch
|
| 301 |
+
nvidia-nvshmem-cu12==3.3.20
|
| 302 |
+
# via torch
|
| 303 |
+
nvidia-nvtx-cu12==12.8.90
|
| 304 |
+
# via torch
|
| 305 |
+
opencv-python==4.11.0.86
|
| 306 |
+
# via -r requirements.in
|
| 307 |
+
packaging==25.0
|
| 308 |
+
# via
|
| 309 |
+
# huggingface-hub
|
| 310 |
+
# ipykernel
|
| 311 |
+
# jupyter-events
|
| 312 |
+
# jupyter-server
|
| 313 |
+
# jupyterlab
|
| 314 |
+
# jupyterlab-server
|
| 315 |
+
# matplotlib
|
| 316 |
+
# nbconvert
|
| 317 |
+
# transformers
|
| 318 |
+
pandas==2.3.3
|
| 319 |
+
# via
|
| 320 |
+
# -r requirements.in
|
| 321 |
+
# dyff-audit
|
| 322 |
+
# dyff-client
|
| 323 |
+
pandocfilters==1.5.1
|
| 324 |
+
# via nbconvert
|
| 325 |
+
parso==0.8.5
|
| 326 |
+
# via jedi
|
| 327 |
+
pexpect==4.9.0
|
| 328 |
+
# via ipython
|
| 329 |
+
pillow==12.0.0
|
| 330 |
+
# via
|
| 331 |
+
# -r requirements.in
|
| 332 |
+
# matplotlib
|
| 333 |
+
# torchvision
|
| 334 |
+
platformdirs==4.5.1
|
| 335 |
+
# via jupyter-core
|
| 336 |
+
ply==3.11
|
| 337 |
+
# via jsonpath-ng
|
| 338 |
+
prometheus-client==0.23.1
|
| 339 |
+
# via jupyter-server
|
| 340 |
+
prompt-toolkit==3.0.52
|
| 341 |
+
# via ipython
|
| 342 |
+
psutil==7.1.3
|
| 343 |
+
# via ipykernel
|
| 344 |
+
ptyprocess==0.7.0
|
| 345 |
+
# via
|
| 346 |
+
# pexpect
|
| 347 |
+
# terminado
|
| 348 |
+
pure-eval==0.2.3
|
| 349 |
+
# via stack-data
|
| 350 |
+
pyarrow==22.0.0
|
| 351 |
+
# via
|
| 352 |
+
# -r requirements.in
|
| 353 |
+
# dyff-audit
|
| 354 |
+
# dyff-client
|
| 355 |
+
# dyff-schema
|
| 356 |
+
pycparser==2.23
|
| 357 |
+
# via cffi
|
| 358 |
+
pydantic==2.12.5
|
| 359 |
+
# via
|
| 360 |
+
# -r requirements.in
|
| 361 |
+
# dyff-audit
|
| 362 |
+
# dyff-client
|
| 363 |
+
# dyff-schema
|
| 364 |
+
# fastapi
|
| 365 |
+
# pydantic-settings
|
| 366 |
+
pydantic-core==2.41.5
|
| 367 |
+
# via pydantic
|
| 368 |
+
pydantic-settings==2.12.0
|
| 369 |
+
# via -r requirements.in
|
| 370 |
+
pygments==2.19.2
|
| 371 |
+
# via
|
| 372 |
+
# ipython
|
| 373 |
+
# ipython-pygments-lexers
|
| 374 |
+
# nbconvert
|
| 375 |
+
pyparsing==3.2.5
|
| 376 |
+
# via matplotlib
|
| 377 |
+
python-dateutil==2.9.0.post0
|
| 378 |
+
# via
|
| 379 |
+
# arrow
|
| 380 |
+
# jupyter-client
|
| 381 |
+
# matplotlib
|
| 382 |
+
# pandas
|
| 383 |
+
python-dotenv==1.2.1
|
| 384 |
+
# via
|
| 385 |
+
# -r requirements.in
|
| 386 |
+
# pydantic-settings
|
| 387 |
+
# uvicorn
|
| 388 |
+
python-json-logger==4.0.0
|
| 389 |
+
# via jupyter-events
|
| 390 |
+
python-multipart==0.0.20
|
| 391 |
+
# via -r requirements.in
|
| 392 |
+
pytz==2025.2
|
| 393 |
+
# via pandas
|
| 394 |
+
pyyaml==6.0.3
|
| 395 |
+
# via
|
| 396 |
+
# huggingface-hub
|
| 397 |
+
# jupyter-events
|
| 398 |
+
# timm
|
| 399 |
+
# transformers
|
| 400 |
+
# uvicorn
|
| 401 |
+
pyzmq==27.1.0
|
| 402 |
+
# via
|
| 403 |
+
# ipykernel
|
| 404 |
+
# jupyter-client
|
| 405 |
+
# jupyter-server
|
| 406 |
+
referencing==0.37.0
|
| 407 |
+
# via
|
| 408 |
+
# jsonschema
|
| 409 |
+
# jsonschema-specifications
|
| 410 |
+
# jupyter-events
|
| 411 |
+
regex==2025.11.3
|
| 412 |
+
# via transformers
|
| 413 |
+
requests==2.32.5
|
| 414 |
+
# via
|
| 415 |
+
# -r requirements.in
|
| 416 |
+
# azure-core
|
| 417 |
+
# google-i18n-address
|
| 418 |
+
# huggingface-hub
|
| 419 |
+
# jupyterlab-server
|
| 420 |
+
# transformers
|
| 421 |
+
rfc3339-validator==0.1.4
|
| 422 |
+
# via
|
| 423 |
+
# jsonschema
|
| 424 |
+
# jupyter-events
|
| 425 |
+
rfc3986-validator==0.1.1
|
| 426 |
+
# via
|
| 427 |
+
# jsonschema
|
| 428 |
+
# jupyter-events
|
| 429 |
+
rfc3987-syntax==1.1.0
|
| 430 |
+
# via jsonschema
|
| 431 |
+
rpds-py==0.30.0
|
| 432 |
+
# via
|
| 433 |
+
# jsonschema
|
| 434 |
+
# referencing
|
| 435 |
+
ruamel-yaml==0.18.16
|
| 436 |
+
# via dyff-audit
|
| 437 |
+
ruamel-yaml-clib==0.2.15
|
| 438 |
+
# via ruamel-yaml
|
| 439 |
+
safetensors==0.7.0
|
| 440 |
+
# via
|
| 441 |
+
# timm
|
| 442 |
+
# transformers
|
| 443 |
+
scipy==1.16.3
|
| 444 |
+
# via lpips
|
| 445 |
+
send2trash==1.8.3
|
| 446 |
+
# via jupyter-server
|
| 447 |
+
setuptools==80.9.0
|
| 448 |
+
# via
|
| 449 |
+
# jupyterlab
|
| 450 |
+
# torch
|
| 451 |
+
six==1.17.0
|
| 452 |
+
# via
|
| 453 |
+
# python-dateutil
|
| 454 |
+
# rfc3339-validator
|
| 455 |
+
sortedcontainers==2.4.0
|
| 456 |
+
# via hypothesis
|
| 457 |
+
soupsieve==2.8
|
| 458 |
+
# via beautifulsoup4
|
| 459 |
+
stack-data==0.6.3
|
| 460 |
+
# via ipython
|
| 461 |
+
starlette==0.50.0
|
| 462 |
+
# via fastapi
|
| 463 |
+
sympy==1.14.0
|
| 464 |
+
# via torch
|
| 465 |
+
terminado==0.18.1
|
| 466 |
+
# via
|
| 467 |
+
# jupyter-server
|
| 468 |
+
# jupyter-server-terminals
|
| 469 |
+
timm==1.0.22
|
| 470 |
+
# via -r requirements.in
|
| 471 |
+
tinycss2==1.4.0
|
| 472 |
+
# via bleach
|
| 473 |
+
tokenizers==0.22.1
|
| 474 |
+
# via transformers
|
| 475 |
+
torch==2.9.1
|
| 476 |
+
# via
|
| 477 |
+
# lpips
|
| 478 |
+
# timm
|
| 479 |
+
# torchvision
|
| 480 |
+
torchvision==0.24.1
|
| 481 |
+
# via
|
| 482 |
+
# lpips
|
| 483 |
+
# timm
|
| 484 |
+
tornado==6.5.2
|
| 485 |
+
# via
|
| 486 |
+
# ipykernel
|
| 487 |
+
# jupyter-client
|
| 488 |
+
# jupyter-server
|
| 489 |
+
# jupyterlab
|
| 490 |
+
# notebook
|
| 491 |
+
# terminado
|
| 492 |
+
tqdm==4.67.1
|
| 493 |
+
# via
|
| 494 |
+
# dyff-client
|
| 495 |
+
# huggingface-hub
|
| 496 |
+
# lpips
|
| 497 |
+
# transformers
|
| 498 |
+
traitlets==5.14.3
|
| 499 |
+
# via
|
| 500 |
+
# ipykernel
|
| 501 |
+
# ipython
|
| 502 |
+
# jupyter-client
|
| 503 |
+
# jupyter-core
|
| 504 |
+
# jupyter-events
|
| 505 |
+
# jupyter-server
|
| 506 |
+
# jupyterlab
|
| 507 |
+
# matplotlib-inline
|
| 508 |
+
# nbclient
|
| 509 |
+
# nbconvert
|
| 510 |
+
# nbformat
|
| 511 |
+
transformers==4.57.3
|
| 512 |
+
# via -r requirements.in
|
| 513 |
+
triton
|
| 514 |
+
# via torch
|
| 515 |
+
typing
|
| 516 |
+
# via -r requirements.in
|
| 517 |
+
typing-extensions
|
| 518 |
+
# via
|
| 519 |
+
# anyio
|
| 520 |
+
# azure-core
|
| 521 |
+
# beautifulsoup4
|
| 522 |
+
# fastapi
|
| 523 |
+
# huggingface-hub
|
| 524 |
+
# pydantic
|
| 525 |
+
# pydantic-core
|
| 526 |
+
# referencing
|
| 527 |
+
# starlette
|
| 528 |
+
# torch
|
| 529 |
+
# typing-inspection
|
| 530 |
+
typing-inspection
|
| 531 |
+
# via
|
| 532 |
+
# pydantic
|
| 533 |
+
# pydantic-settings
|
| 534 |
+
tzdata
|
| 535 |
+
# via
|
| 536 |
+
# arrow
|
| 537 |
+
# pandas
|
| 538 |
+
uri-template
|
| 539 |
+
# via jsonschema
|
| 540 |
+
urllib3
|
| 541 |
+
# via requests
|
| 542 |
+
uvicorn
|
| 543 |
+
# via -r requirements.in
|
| 544 |
+
uvloop
|
| 545 |
+
# via uvicorn
|
| 546 |
+
watchfiles
|
| 547 |
+
# via uvicorn
|
| 548 |
+
wcwidth==0.2.14
|
| 549 |
+
# via prompt-toolkit
|
| 550 |
+
webcolors==25.10.0
|
| 551 |
+
# via jsonschema
|
| 552 |
+
webencodings==0.5.1
|
| 553 |
+
# via
|
| 554 |
+
# bleach
|
| 555 |
+
# tinycss2
|
| 556 |
+
websocket-client==1.9.0
|
| 557 |
+
# via jupyter-server
|
| 558 |
+
websockets==15.0.1
|
| 559 |
+
# via
|
| 560 |
+
# dyff-client
|
| 561 |
+
# uvicorn
|
requirements.in
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dyff-audit
|
| 2 |
+
dyff-client
|
| 3 |
+
dyff-schema>=0.39.1
|
| 4 |
+
|
| 5 |
+
click
|
| 6 |
+
|
| 7 |
+
# Web framework
|
| 8 |
+
fastapi
|
| 9 |
+
uvicorn[standard]
|
| 10 |
+
|
| 11 |
+
# Configuration management
|
| 12 |
+
pydantic>=2.0.0,<3.0.0
|
| 13 |
+
pydantic-settings
|
| 14 |
+
python-dotenv
|
| 15 |
+
|
| 16 |
+
# File upload handling
|
| 17 |
+
python-multipart
|
| 18 |
+
|
| 19 |
+
# ML/AI dependencies (newly added)
|
| 20 |
+
transformers>=4.35.0
|
| 21 |
+
pillow>=10.0.0
|
| 22 |
+
|
| 23 |
+
# Dataset generation and testing
|
| 24 |
+
pyarrow>=14.0.0
|
| 25 |
+
numpy
|
| 26 |
+
pandas
|
| 27 |
+
requests>=2.25.0
|
| 28 |
+
opencv-python
|
| 29 |
+
matplotlib
|
| 30 |
+
typing
|
| 31 |
+
einops
|
| 32 |
+
timm
|
| 33 |
+
beartype
|
| 34 |
+
lpips
|
requirements.torch.cpu.in
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--index-url https://download.pytorch.org/whl/cpu
|
| 2 |
+
torch==2.9.1+cpu
|
| 3 |
+
torchvision==0.24.1+cpu
|
requirements.torch.cpu.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#--index-url https://download.pytorch.org/whl/cu126
|
| 2 |
+
#torch==2.9.1
|
| 3 |
+
#torchvision==0.24.1
|
| 4 |
+
#torchaudio==2.9.0
|
| 5 |
+
|
| 6 |
+
--index-url https://download.pytorch.org/whl/nightly/cu121
|
| 7 |
+
torch
|
| 8 |
+
torchvision
|
| 9 |
+
torchaudio
|
requirements.torch.gpu.in
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
requirements.torch.gpu.txt
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file was autogenerated by uv via the following command:
|
| 2 |
+
# uv pip compile --python-version 3.12 -o requirements.torch.gpu.txt requirements.torch.gpu.in
|
| 3 |
+
filelock==3.20.0
|
| 4 |
+
# via torch
|
| 5 |
+
fsspec==2025.12.0
|
| 6 |
+
# via torch
|
| 7 |
+
jinja2==3.1.6
|
| 8 |
+
# via torch
|
| 9 |
+
markupsafe==3.0.3
|
| 10 |
+
# via jinja2
|
| 11 |
+
mpmath==1.3.0
|
| 12 |
+
# via sympy
|
| 13 |
+
networkx==3.6.1
|
| 14 |
+
# via torch
|
| 15 |
+
numpy==2.3.5
|
| 16 |
+
# via torchvision
|
| 17 |
+
nvidia-cublas-cu12==12.8.4.1
|
| 18 |
+
# via
|
| 19 |
+
# nvidia-cudnn-cu12
|
| 20 |
+
# nvidia-cusolver-cu12
|
| 21 |
+
# torch
|
| 22 |
+
nvidia-cuda-cupti-cu12==12.8.90
|
| 23 |
+
# via torch
|
| 24 |
+
nvidia-cuda-nvrtc-cu12==12.8.93
|
| 25 |
+
# via torch
|
| 26 |
+
nvidia-cuda-runtime-cu12==12.8.90
|
| 27 |
+
# via torch
|
| 28 |
+
nvidia-cudnn-cu12==9.10.2.21
|
| 29 |
+
# via torch
|
| 30 |
+
nvidia-cufft-cu12==11.3.3.83
|
| 31 |
+
# via torch
|
| 32 |
+
nvidia-cufile-cu12==1.13.1.3
|
| 33 |
+
# via torch
|
| 34 |
+
nvidia-curand-cu12==10.3.9.90
|
| 35 |
+
# via torch
|
| 36 |
+
nvidia-cusolver-cu12==11.7.3.90
|
| 37 |
+
# via torch
|
| 38 |
+
nvidia-cusparse-cu12==12.5.8.93
|
| 39 |
+
# via
|
| 40 |
+
# nvidia-cusolver-cu12
|
| 41 |
+
# torch
|
| 42 |
+
nvidia-cusparselt-cu12==0.7.1
|
| 43 |
+
# via torch
|
| 44 |
+
nvidia-nccl-cu12==2.27.5
|
| 45 |
+
# via torch
|
| 46 |
+
nvidia-nvjitlink-cu12==12.8.93
|
| 47 |
+
# via
|
| 48 |
+
# nvidia-cufft-cu12
|
| 49 |
+
# nvidia-cusolver-cu12
|
| 50 |
+
# nvidia-cusparse-cu12
|
| 51 |
+
# torch
|
| 52 |
+
nvidia-nvshmem-cu12==3.3.20
|
| 53 |
+
# via torch
|
| 54 |
+
nvidia-nvtx-cu12==12.8.90
|
| 55 |
+
# via torch
|
| 56 |
+
pillow==12.0.0
|
| 57 |
+
# via torchvision
|
| 58 |
+
setuptools==80.9.0
|
| 59 |
+
# via torch
|
| 60 |
+
sympy==1.14.0
|
| 61 |
+
# via torch
|
| 62 |
+
torch==2.9.1
|
| 63 |
+
# via
|
| 64 |
+
# -r requirements.torch.gpu.in
|
| 65 |
+
# torchvision
|
| 66 |
+
torchvision==0.24.1
|
| 67 |
+
# via -r requirements.torch.gpu.in
|
| 68 |
+
triton==3.5.1
|
| 69 |
+
# via torch
|
| 70 |
+
typing-extensions==4.15.0
|
| 71 |
+
# via torch
|
requirements.txt
ADDED
|
@@ -0,0 +1,530 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file was autogenerated by uv via the following command:
|
| 2 |
+
# uv pip compile --python-version 3.12 -o requirements.txt requirements.in
|
| 3 |
+
absl-py==2.3.1
|
| 4 |
+
# via dyff-client
|
| 5 |
+
annotated-types==0.7.0
|
| 6 |
+
# via pydantic
|
| 7 |
+
anyio==3.7.1
|
| 8 |
+
# via
|
| 9 |
+
# fastapi
|
| 10 |
+
# httpx
|
| 11 |
+
# jupyter-server
|
| 12 |
+
# starlette
|
| 13 |
+
# watchfiles
|
| 14 |
+
argon2-cffi==25.1.0
|
| 15 |
+
# via jupyter-server
|
| 16 |
+
argon2-cffi-bindings==25.1.0
|
| 17 |
+
# via argon2-cffi
|
| 18 |
+
arrow==1.4.0
|
| 19 |
+
# via isoduration
|
| 20 |
+
asttokens==3.0.1
|
| 21 |
+
# via stack-data
|
| 22 |
+
async-lru==2.0.5
|
| 23 |
+
# via jupyterlab
|
| 24 |
+
attrs==25.4.0
|
| 25 |
+
# via
|
| 26 |
+
# jsonschema
|
| 27 |
+
# referencing
|
| 28 |
+
azure-core==1.36.0
|
| 29 |
+
# via dyff-client
|
| 30 |
+
babel==2.17.0
|
| 31 |
+
# via jupyterlab-server
|
| 32 |
+
beautifulsoup4==4.14.2
|
| 33 |
+
# via
|
| 34 |
+
# dyff-audit
|
| 35 |
+
# nbconvert
|
| 36 |
+
bleach==6.3.0
|
| 37 |
+
# via nbconvert
|
| 38 |
+
canonicaljson==2.0.0
|
| 39 |
+
# via dyff-schema
|
| 40 |
+
certifi==2025.11.12
|
| 41 |
+
# via
|
| 42 |
+
# httpcore
|
| 43 |
+
# httpx
|
| 44 |
+
# requests
|
| 45 |
+
cffi==2.0.0
|
| 46 |
+
# via argon2-cffi-bindings
|
| 47 |
+
charset-normalizer==3.4.4
|
| 48 |
+
# via requests
|
| 49 |
+
click==8.3.1
|
| 50 |
+
# via
|
| 51 |
+
# -r requirements.in
|
| 52 |
+
# uvicorn
|
| 53 |
+
comm==0.2.3
|
| 54 |
+
# via ipykernel
|
| 55 |
+
debugpy==1.8.17
|
| 56 |
+
# via ipykernel
|
| 57 |
+
decorator==5.2.1
|
| 58 |
+
# via ipython
|
| 59 |
+
defusedxml==0.7.1
|
| 60 |
+
# via nbconvert
|
| 61 |
+
dnspython==2.8.0
|
| 62 |
+
# via email-validator
|
| 63 |
+
dyff-audit==0.16.1
|
| 64 |
+
# via -r requirements.in
|
| 65 |
+
dyff-client==0.23.5
|
| 66 |
+
# via
|
| 67 |
+
# -r requirements.in
|
| 68 |
+
# dyff-audit
|
| 69 |
+
dyff-schema==0.39.1
|
| 70 |
+
# via
|
| 71 |
+
# -r requirements.in
|
| 72 |
+
# dyff-audit
|
| 73 |
+
# dyff-client
|
| 74 |
+
email-validator==2.3.0
|
| 75 |
+
# via dyff-schema
|
| 76 |
+
executing==2.2.1
|
| 77 |
+
# via stack-data
|
| 78 |
+
fastapi==0.104.1
|
| 79 |
+
# via -r requirements.in
|
| 80 |
+
fastjsonschema==2.21.2
|
| 81 |
+
# via nbformat
|
| 82 |
+
filelock==3.20.0
|
| 83 |
+
# via
|
| 84 |
+
# huggingface-hub
|
| 85 |
+
# torch
|
| 86 |
+
# transformers
|
| 87 |
+
fqdn==1.5.1
|
| 88 |
+
# via jsonschema
|
| 89 |
+
fsspec==2025.10.0
|
| 90 |
+
# via
|
| 91 |
+
# huggingface-hub
|
| 92 |
+
# torch
|
| 93 |
+
google-i18n-address==3.1.1
|
| 94 |
+
# via dyff-schema
|
| 95 |
+
h11==0.16.0
|
| 96 |
+
# via
|
| 97 |
+
# httpcore
|
| 98 |
+
# uvicorn
|
| 99 |
+
hf-xet==1.2.0
|
| 100 |
+
# via huggingface-hub
|
| 101 |
+
httpcore==1.0.9
|
| 102 |
+
# via httpx
|
| 103 |
+
httptools==0.7.1
|
| 104 |
+
# via uvicorn
|
| 105 |
+
httpx==0.28.1
|
| 106 |
+
# via
|
| 107 |
+
# dyff-client
|
| 108 |
+
# jupyterlab
|
| 109 |
+
huggingface-hub==0.36.0
|
| 110 |
+
# via
|
| 111 |
+
# tokenizers
|
| 112 |
+
# transformers
|
| 113 |
+
hypothesis==6.148.1
|
| 114 |
+
# via
|
| 115 |
+
# dyff-schema
|
| 116 |
+
# hypothesis-jsonschema
|
| 117 |
+
hypothesis-jsonschema==0.23.1
|
| 118 |
+
# via dyff-schema
|
| 119 |
+
idna==3.11
|
| 120 |
+
# via
|
| 121 |
+
# anyio
|
| 122 |
+
# email-validator
|
| 123 |
+
# httpx
|
| 124 |
+
# jsonschema
|
| 125 |
+
# requests
|
| 126 |
+
ipykernel==7.1.0
|
| 127 |
+
# via jupyterlab
|
| 128 |
+
ipython==9.7.0
|
| 129 |
+
# via ipykernel
|
| 130 |
+
ipython-pygments-lexers==1.1.1
|
| 131 |
+
# via ipython
|
| 132 |
+
isodate==0.7.2
|
| 133 |
+
# via dyff-client
|
| 134 |
+
isoduration==20.11.0
|
| 135 |
+
# via jsonschema
|
| 136 |
+
jedi==0.19.2
|
| 137 |
+
# via ipython
|
| 138 |
+
jinja2==3.1.6
|
| 139 |
+
# via
|
| 140 |
+
# jupyter-server
|
| 141 |
+
# jupyterlab
|
| 142 |
+
# jupyterlab-server
|
| 143 |
+
# nbconvert
|
| 144 |
+
# torch
|
| 145 |
+
json5==0.12.1
|
| 146 |
+
# via jupyterlab-server
|
| 147 |
+
jsonpath-ng==1.7.0
|
| 148 |
+
# via
|
| 149 |
+
# dyff-client
|
| 150 |
+
# dyff-schema
|
| 151 |
+
jsonpointer==3.0.0
|
| 152 |
+
# via jsonschema
|
| 153 |
+
jsonschema==4.25.1
|
| 154 |
+
# via
|
| 155 |
+
# hypothesis-jsonschema
|
| 156 |
+
# jupyter-events
|
| 157 |
+
# jupyterlab-server
|
| 158 |
+
# nbformat
|
| 159 |
+
jsonschema-specifications==2025.9.1
|
| 160 |
+
# via jsonschema
|
| 161 |
+
jupyter-client==8.6.3
|
| 162 |
+
# via
|
| 163 |
+
# ipykernel
|
| 164 |
+
# jupyter-server
|
| 165 |
+
# nbclient
|
| 166 |
+
jupyter-core==5.9.1
|
| 167 |
+
# via
|
| 168 |
+
# ipykernel
|
| 169 |
+
# jupyter-client
|
| 170 |
+
# jupyter-server
|
| 171 |
+
# jupyterlab
|
| 172 |
+
# nbclient
|
| 173 |
+
# nbconvert
|
| 174 |
+
# nbformat
|
| 175 |
+
jupyter-events==0.12.0
|
| 176 |
+
# via jupyter-server
|
| 177 |
+
jupyter-lsp==2.3.0
|
| 178 |
+
# via jupyterlab
|
| 179 |
+
jupyter-server==2.17.0
|
| 180 |
+
# via
|
| 181 |
+
# jupyter-lsp
|
| 182 |
+
# jupyterlab
|
| 183 |
+
# jupyterlab-server
|
| 184 |
+
# notebook
|
| 185 |
+
# notebook-shim
|
| 186 |
+
jupyter-server-terminals==0.5.3
|
| 187 |
+
# via jupyter-server
|
| 188 |
+
jupyterlab==4.4.10
|
| 189 |
+
# via notebook
|
| 190 |
+
jupyterlab-pygments==0.3.0
|
| 191 |
+
# via nbconvert
|
| 192 |
+
jupyterlab-server==2.28.0
|
| 193 |
+
# via
|
| 194 |
+
# jupyterlab
|
| 195 |
+
# notebook
|
| 196 |
+
lark==1.3.1
|
| 197 |
+
# via rfc3987-syntax
|
| 198 |
+
lxml==6.0.2
|
| 199 |
+
# via dyff-audit
|
| 200 |
+
markupsafe==3.0.3
|
| 201 |
+
# via
|
| 202 |
+
# jinja2
|
| 203 |
+
# nbconvert
|
| 204 |
+
matplotlib-inline==0.2.1
|
| 205 |
+
# via
|
| 206 |
+
# ipykernel
|
| 207 |
+
# ipython
|
| 208 |
+
mistune==3.1.4
|
| 209 |
+
# via nbconvert
|
| 210 |
+
mpmath==1.3.0
|
| 211 |
+
# via sympy
|
| 212 |
+
nbclient==0.10.2
|
| 213 |
+
# via nbconvert
|
| 214 |
+
nbconvert==7.16.6
|
| 215 |
+
# via
|
| 216 |
+
# dyff-audit
|
| 217 |
+
# jupyter-server
|
| 218 |
+
nbformat==5.10.4
|
| 219 |
+
# via
|
| 220 |
+
# dyff-audit
|
| 221 |
+
# jupyter-server
|
| 222 |
+
# nbclient
|
| 223 |
+
# nbconvert
|
| 224 |
+
nest-asyncio==1.6.0
|
| 225 |
+
# via ipykernel
|
| 226 |
+
networkx==3.5
|
| 227 |
+
# via torch
|
| 228 |
+
notebook==7.4.7
|
| 229 |
+
# via dyff-audit
|
| 230 |
+
notebook-shim==0.2.4
|
| 231 |
+
# via
|
| 232 |
+
# jupyterlab
|
| 233 |
+
# notebook
|
| 234 |
+
numpy==1.26.4
|
| 235 |
+
# via
|
| 236 |
+
# -r requirements.in
|
| 237 |
+
# dyff-audit
|
| 238 |
+
# dyff-client
|
| 239 |
+
# dyff-schema
|
| 240 |
+
# pandas
|
| 241 |
+
# transformers
|
| 242 |
+
nvidia-cublas-cu12==12.8.4.1
|
| 243 |
+
# via
|
| 244 |
+
# nvidia-cudnn-cu12
|
| 245 |
+
# nvidia-cusolver-cu12
|
| 246 |
+
# torch
|
| 247 |
+
nvidia-cuda-cupti-cu12==12.8.90
|
| 248 |
+
# via torch
|
| 249 |
+
nvidia-cuda-nvrtc-cu12==12.8.93
|
| 250 |
+
# via torch
|
| 251 |
+
nvidia-cuda-runtime-cu12==12.8.90
|
| 252 |
+
# via torch
|
| 253 |
+
nvidia-cudnn-cu12==9.10.2.21
|
| 254 |
+
# via torch
|
| 255 |
+
nvidia-cufft-cu12==11.3.3.83
|
| 256 |
+
# via torch
|
| 257 |
+
nvidia-cufile-cu12==1.13.1.3
|
| 258 |
+
# via torch
|
| 259 |
+
nvidia-curand-cu12==10.3.9.90
|
| 260 |
+
# via torch
|
| 261 |
+
nvidia-cusolver-cu12==11.7.3.90
|
| 262 |
+
# via torch
|
| 263 |
+
nvidia-cusparse-cu12==12.5.8.93
|
| 264 |
+
# via
|
| 265 |
+
# nvidia-cusolver-cu12
|
| 266 |
+
# torch
|
| 267 |
+
nvidia-cusparselt-cu12==0.7.1
|
| 268 |
+
# via torch
|
| 269 |
+
nvidia-nccl-cu12==2.27.5
|
| 270 |
+
# via torch
|
| 271 |
+
nvidia-nvjitlink-cu12==12.8.93
|
| 272 |
+
# via
|
| 273 |
+
# nvidia-cufft-cu12
|
| 274 |
+
# nvidia-cusolver-cu12
|
| 275 |
+
# nvidia-cusparse-cu12
|
| 276 |
+
# torch
|
| 277 |
+
nvidia-nvshmem-cu12==3.3.20
|
| 278 |
+
# via torch
|
| 279 |
+
nvidia-nvtx-cu12==12.8.90
|
| 280 |
+
# via torch
|
| 281 |
+
packaging==25.0
|
| 282 |
+
# via
|
| 283 |
+
# huggingface-hub
|
| 284 |
+
# ipykernel
|
| 285 |
+
# jupyter-events
|
| 286 |
+
# jupyter-server
|
| 287 |
+
# jupyterlab
|
| 288 |
+
# jupyterlab-server
|
| 289 |
+
# nbconvert
|
| 290 |
+
# transformers
|
| 291 |
+
pandas==2.3.3
|
| 292 |
+
# via
|
| 293 |
+
# -r requirements.in
|
| 294 |
+
# dyff-audit
|
| 295 |
+
# dyff-client
|
| 296 |
+
pandocfilters==1.5.1
|
| 297 |
+
# via nbconvert
|
| 298 |
+
parso==0.8.5
|
| 299 |
+
# via jedi
|
| 300 |
+
pexpect==4.9.0
|
| 301 |
+
# via ipython
|
| 302 |
+
pillow==12.0.0
|
| 303 |
+
# via -r requirements.in
|
| 304 |
+
platformdirs==4.5.0
|
| 305 |
+
# via jupyter-core
|
| 306 |
+
ply==3.11
|
| 307 |
+
# via jsonpath-ng
|
| 308 |
+
prometheus-client==0.23.1
|
| 309 |
+
# via jupyter-server
|
| 310 |
+
prompt-toolkit==3.0.52
|
| 311 |
+
# via ipython
|
| 312 |
+
psutil==7.1.3
|
| 313 |
+
# via ipykernel
|
| 314 |
+
ptyprocess==0.7.0
|
| 315 |
+
# via
|
| 316 |
+
# pexpect
|
| 317 |
+
# terminado
|
| 318 |
+
pure-eval==0.2.3
|
| 319 |
+
# via stack-data
|
| 320 |
+
pyarrow==22.0.0
|
| 321 |
+
# via
|
| 322 |
+
# -r requirements.in
|
| 323 |
+
# dyff-audit
|
| 324 |
+
# dyff-client
|
| 325 |
+
# dyff-schema
|
| 326 |
+
pycparser==2.23
|
| 327 |
+
# via cffi
|
| 328 |
+
pydantic==2.5.0
|
| 329 |
+
# via
|
| 330 |
+
# -r requirements.in
|
| 331 |
+
# dyff-audit
|
| 332 |
+
# dyff-client
|
| 333 |
+
# dyff-schema
|
| 334 |
+
# fastapi
|
| 335 |
+
# pydantic-settings
|
| 336 |
+
pydantic-core==2.14.1
|
| 337 |
+
# via pydantic
|
| 338 |
+
pydantic-settings==2.0.3
|
| 339 |
+
# via -r requirements.in
|
| 340 |
+
pygments==2.19.2
|
| 341 |
+
# via
|
| 342 |
+
# ipython
|
| 343 |
+
# ipython-pygments-lexers
|
| 344 |
+
# nbconvert
|
| 345 |
+
python-dateutil==2.9.0.post0
|
| 346 |
+
# via
|
| 347 |
+
# arrow
|
| 348 |
+
# jupyter-client
|
| 349 |
+
# pandas
|
| 350 |
+
python-dotenv==0.21.0
|
| 351 |
+
# via
|
| 352 |
+
# -r requirements.in
|
| 353 |
+
# pydantic-settings
|
| 354 |
+
# uvicorn
|
| 355 |
+
python-json-logger==4.0.0
|
| 356 |
+
# via jupyter-events
|
| 357 |
+
python-multipart==0.0.6
|
| 358 |
+
# via -r requirements.in
|
| 359 |
+
pytz==2025.2
|
| 360 |
+
# via pandas
|
| 361 |
+
pyyaml==6.0.3
|
| 362 |
+
# via
|
| 363 |
+
# huggingface-hub
|
| 364 |
+
# jupyter-events
|
| 365 |
+
# transformers
|
| 366 |
+
# uvicorn
|
| 367 |
+
pyzmq==27.1.0
|
| 368 |
+
# via
|
| 369 |
+
# ipykernel
|
| 370 |
+
# jupyter-client
|
| 371 |
+
# jupyter-server
|
| 372 |
+
referencing==0.37.0
|
| 373 |
+
# via
|
| 374 |
+
# jsonschema
|
| 375 |
+
# jsonschema-specifications
|
| 376 |
+
# jupyter-events
|
| 377 |
+
regex==2025.11.3
|
| 378 |
+
# via transformers
|
| 379 |
+
requests==2.32.5
|
| 380 |
+
# via
|
| 381 |
+
# -r requirements.in
|
| 382 |
+
# azure-core
|
| 383 |
+
# google-i18n-address
|
| 384 |
+
# huggingface-hub
|
| 385 |
+
# jupyterlab-server
|
| 386 |
+
# transformers
|
| 387 |
+
rfc3339-validator==0.1.4
|
| 388 |
+
# via
|
| 389 |
+
# jsonschema
|
| 390 |
+
# jupyter-events
|
| 391 |
+
rfc3986-validator==0.1.1
|
| 392 |
+
# via
|
| 393 |
+
# jsonschema
|
| 394 |
+
# jupyter-events
|
| 395 |
+
rfc3987-syntax==1.1.0
|
| 396 |
+
# via jsonschema
|
| 397 |
+
rpds-py==0.29.0
|
| 398 |
+
# via
|
| 399 |
+
# jsonschema
|
| 400 |
+
# referencing
|
| 401 |
+
ruamel-yaml==0.18.16
|
| 402 |
+
# via dyff-audit
|
| 403 |
+
ruamel-yaml-clib==0.2.15
|
| 404 |
+
# via ruamel-yaml
|
| 405 |
+
safetensors==0.6.2
|
| 406 |
+
# via transformers
|
| 407 |
+
send2trash==1.8.3
|
| 408 |
+
# via jupyter-server
|
| 409 |
+
setuptools==80.9.0
|
| 410 |
+
# via
|
| 411 |
+
# jupyterlab
|
| 412 |
+
# torch
|
| 413 |
+
six==1.17.0
|
| 414 |
+
# via
|
| 415 |
+
# python-dateutil
|
| 416 |
+
# rfc3339-validator
|
| 417 |
+
sniffio==1.3.1
|
| 418 |
+
# via anyio
|
| 419 |
+
sortedcontainers==2.4.0
|
| 420 |
+
# via hypothesis
|
| 421 |
+
soupsieve==2.8
|
| 422 |
+
# via beautifulsoup4
|
| 423 |
+
stack-data==0.6.3
|
| 424 |
+
# via ipython
|
| 425 |
+
starlette==0.27.0
|
| 426 |
+
# via fastapi
|
| 427 |
+
sympy==1.14.0
|
| 428 |
+
# via torch
|
| 429 |
+
terminado==0.18.1
|
| 430 |
+
# via
|
| 431 |
+
# jupyter-server
|
| 432 |
+
# jupyter-server-terminals
|
| 433 |
+
tinycss2==1.4.0
|
| 434 |
+
# via bleach
|
| 435 |
+
tokenizers==0.22.1
|
| 436 |
+
# via transformers
|
| 437 |
+
torch==2.9.1
|
| 438 |
+
# via -r requirements.in
|
| 439 |
+
tornado==6.5.2
|
| 440 |
+
# via
|
| 441 |
+
# ipykernel
|
| 442 |
+
# jupyter-client
|
| 443 |
+
# jupyter-server
|
| 444 |
+
# jupyterlab
|
| 445 |
+
# notebook
|
| 446 |
+
# terminado
|
| 447 |
+
tqdm==4.67.1
|
| 448 |
+
# via
|
| 449 |
+
# dyff-client
|
| 450 |
+
# huggingface-hub
|
| 451 |
+
# transformers
|
| 452 |
+
traitlets==5.14.3
|
| 453 |
+
# via
|
| 454 |
+
# ipykernel
|
| 455 |
+
# ipython
|
| 456 |
+
# jupyter-client
|
| 457 |
+
# jupyter-core
|
| 458 |
+
# jupyter-events
|
| 459 |
+
# jupyter-server
|
| 460 |
+
# jupyterlab
|
| 461 |
+
# matplotlib-inline
|
| 462 |
+
# nbclient
|
| 463 |
+
# nbconvert
|
| 464 |
+
# nbformat
|
| 465 |
+
transformers==4.57.1
|
| 466 |
+
# via -r requirements.in
|
| 467 |
+
triton==3.5.1
|
| 468 |
+
# via torch
|
| 469 |
+
typing-extensions==4.15.0
|
| 470 |
+
# via
|
| 471 |
+
# azure-core
|
| 472 |
+
# beautifulsoup4
|
| 473 |
+
# fastapi
|
| 474 |
+
# huggingface-hub
|
| 475 |
+
# pydantic
|
| 476 |
+
# pydantic-core
|
| 477 |
+
# referencing
|
| 478 |
+
# torch
|
| 479 |
+
tzdata==2025.2
|
| 480 |
+
# via
|
| 481 |
+
# arrow
|
| 482 |
+
# pandas
|
| 483 |
+
uri-template==1.3.0
|
| 484 |
+
# via jsonschema
|
| 485 |
+
urllib3==2.5.0
|
| 486 |
+
# via requests
|
| 487 |
+
uvicorn==0.24.0
|
| 488 |
+
# via -r requirements.in
|
| 489 |
+
uvloop==0.22.1
|
| 490 |
+
# via uvicorn
|
| 491 |
+
watchfiles==1.1.1
|
| 492 |
+
# via uvicorn
|
| 493 |
+
wcwidth==0.2.14
|
| 494 |
+
# via prompt-toolkit
|
| 495 |
+
webcolors==25.10.0
|
| 496 |
+
# via jsonschema
|
| 497 |
+
webencodings==0.5.1
|
| 498 |
+
# via
|
| 499 |
+
# bleach
|
| 500 |
+
# tinycss2
|
| 501 |
+
websocket-client==1.9.0
|
| 502 |
+
# via jupyter-server
|
| 503 |
+
websockets==15.0.1
|
| 504 |
+
# via
|
| 505 |
+
# dyff-client
|
| 506 |
+
# uvicorn
|
| 507 |
+
|
| 508 |
+
numpy
|
| 509 |
+
pandas
|
| 510 |
+
scipy==1.12.0
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
opencv-contrib-python==4.11.0.86
|
| 514 |
+
av==14.0.1
|
| 515 |
+
decord
|
| 516 |
+
kornia==0.7.4
|
| 517 |
+
timm
|
| 518 |
+
einops
|
| 519 |
+
beartype
|
| 520 |
+
lpips==0.1.4
|
| 521 |
+
imageio
|
| 522 |
+
|
| 523 |
+
audiomentations==0.36.0
|
| 524 |
+
librosa==0.10.2.post1
|
| 525 |
+
soundfile==0.12.1
|
| 526 |
+
|
| 527 |
+
transformers
|
| 528 |
+
datasets
|
| 529 |
+
tqdm
|
| 530 |
+
opencv-python
|
response.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"logprobs":[-0.859380304813385,-1.2701971530914307,-2.1918208599090576,-1.69235098361969],"localizationMask":{"mediaType":"image/png","data":"iVBORw0KGgoAAAANSUhEUgAAA8AAAAKDAQAAAAD9Fl5AAAAAu0lEQVR4nO3NsREAMAgDMWD/nZMVKEwn1T5/FQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAMCl3g5f+HC24TRhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAj70gwKsTlmdBwAAAABJRU5ErkJggg=="}}
|
scripts/generate_test_datasets.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
PyArrow Dataset Generator for ML Inference Service
|
| 4 |
+
|
| 5 |
+
Generates test datasets for academic challenges and model validation.
|
| 6 |
+
Creates 100 PyArrow datasets with various image types and test scenarios.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import base64
|
| 10 |
+
import json
|
| 11 |
+
import random
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Dict, List, Any, Tuple
|
| 14 |
+
import io
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import pyarrow as pa
|
| 18 |
+
import pyarrow.parquet as pq
|
| 19 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class TestDatasetGenerator:
|
| 23 |
+
def __init__(self, output_dir: str = "test_datasets"):
|
| 24 |
+
self.output_dir = Path(output_dir)
|
| 25 |
+
self.output_dir.mkdir(exist_ok=True)
|
| 26 |
+
|
| 27 |
+
# ImageNet class labels (sample for testing)
|
| 28 |
+
self.imagenet_labels = [
|
| 29 |
+
"tench", "goldfish", "great_white_shark", "tiger_shark", "hammerhead",
|
| 30 |
+
"electric_ray", "stingray", "cock", "hen", "ostrich", "brambling",
|
| 31 |
+
"goldfinch", "house_finch", "junco", "indigo_bunting", "robin",
|
| 32 |
+
"bulbul", "jay", "magpie", "chickadee", "water_ouzel", "kite",
|
| 33 |
+
"bald_eagle", "vulture", "great_grey_owl", "European_fire_salamander",
|
| 34 |
+
"common_newt", "eft", "spotted_salamander", "axolotl", "bullfrog",
|
| 35 |
+
"tree_frog", "tailed_frog", "loggerhead", "leatherback_turtle",
|
| 36 |
+
"mud_turtle", "terrapin", "box_turtle", "banded_gecko", "common_iguana",
|
| 37 |
+
"American_chameleon", "whiptail", "agama", "frilled_lizard", "alligator_lizard",
|
| 38 |
+
"Gila_monster", "green_lizard", "African_chameleon", "Komodo_dragon",
|
| 39 |
+
"African_crocodile", "American_alligator", "triceratops", "thunder_snake"
|
| 40 |
+
]
|
| 41 |
+
|
| 42 |
+
def create_synthetic_image(self, width: int = 224, height: int = 224,
|
| 43 |
+
image_type: str = "random") -> Image.Image:
|
| 44 |
+
"""Create synthetic images for testing."""
|
| 45 |
+
if image_type == "random":
|
| 46 |
+
# Random noise image
|
| 47 |
+
array = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
|
| 48 |
+
return Image.fromarray(array)
|
| 49 |
+
|
| 50 |
+
elif image_type == "geometric":
|
| 51 |
+
# Geometric patterns
|
| 52 |
+
img = Image.new('RGB', (width, height), color='white')
|
| 53 |
+
draw = ImageDraw.Draw(img)
|
| 54 |
+
|
| 55 |
+
# Draw random shapes
|
| 56 |
+
for _ in range(random.randint(3, 8)):
|
| 57 |
+
color = tuple(random.randint(0, 255) for _ in range(3))
|
| 58 |
+
shape_type = random.choice(['rectangle', 'ellipse'])
|
| 59 |
+
x1, y1 = random.randint(0, width//2), random.randint(0, height//2)
|
| 60 |
+
x2, y2 = x1 + random.randint(20, width//2), y1 + random.randint(20, height//2)
|
| 61 |
+
|
| 62 |
+
if shape_type == 'rectangle':
|
| 63 |
+
draw.rectangle([x1, y1, x2, y2], fill=color)
|
| 64 |
+
else:
|
| 65 |
+
draw.ellipse([x1, y1, x2, y2], fill=color)
|
| 66 |
+
|
| 67 |
+
return img
|
| 68 |
+
|
| 69 |
+
elif image_type == "gradient":
|
| 70 |
+
array = np.zeros((height, width, 3), dtype=np.uint8)
|
| 71 |
+
for i in range(height):
|
| 72 |
+
for j in range(width):
|
| 73 |
+
array[i, j] = [i * 255 // height, j * 255 // width, (i + j) * 255 // (height + width)]
|
| 74 |
+
return Image.fromarray(array)
|
| 75 |
+
|
| 76 |
+
elif image_type == "text":
|
| 77 |
+
img = Image.new('RGB', (width, height), color='white')
|
| 78 |
+
draw = ImageDraw.Draw(img)
|
| 79 |
+
|
| 80 |
+
try:
|
| 81 |
+
font = ImageFont.load_default()
|
| 82 |
+
except:
|
| 83 |
+
font = None
|
| 84 |
+
|
| 85 |
+
text = f"Test Image {random.randint(1, 1000)}"
|
| 86 |
+
draw.text((width//4, height//2), text, fill='black', font=font)
|
| 87 |
+
return img
|
| 88 |
+
|
| 89 |
+
else:
|
| 90 |
+
color = tuple(random.randint(0, 255) for _ in range(3))
|
| 91 |
+
return Image.new('RGB', (width, height), color=color)
|
| 92 |
+
|
| 93 |
+
def image_to_base64(self, image: Image.Image, format: str = "JPEG") -> str:
|
| 94 |
+
"""Convert PIL image to base64 string."""
|
| 95 |
+
buffer = io.BytesIO()
|
| 96 |
+
image.save(buffer, format=format)
|
| 97 |
+
image_bytes = buffer.getvalue()
|
| 98 |
+
return base64.b64encode(image_bytes).decode('utf-8')
|
| 99 |
+
|
| 100 |
+
def create_api_request(self, image_b64: str, media_type: str = "image/jpeg") -> Dict[str, Any]:
|
| 101 |
+
"""Create API request structure matching your service."""
|
| 102 |
+
return {
|
| 103 |
+
"image": {
|
| 104 |
+
"mediaType": media_type,
|
| 105 |
+
"data": image_b64
|
| 106 |
+
}
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
def create_expected_response(self, model_name: str = "microsoft/resnet-18",
|
| 110 |
+
media_type: str = "image/jpeg") -> Dict[str, Any]:
|
| 111 |
+
"""Create expected response structure."""
|
| 112 |
+
prediction = random.choice(self.imagenet_labels)
|
| 113 |
+
return {
|
| 114 |
+
"prediction": prediction,
|
| 115 |
+
"confidence": round(random.uniform(0.3, 0.99), 4),
|
| 116 |
+
"predicted_label": random.randint(0, len(self.imagenet_labels) - 1),
|
| 117 |
+
"model": model_name,
|
| 118 |
+
"mediaType": media_type
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
def generate_standard_datasets(self, count: int = 25) -> List[Dict[str, Any]]:
|
| 122 |
+
"""Generate standard test cases with normal images."""
|
| 123 |
+
datasets = []
|
| 124 |
+
|
| 125 |
+
for i in range(count):
|
| 126 |
+
image_types = ["random", "geometric", "gradient", "text", "solid"]
|
| 127 |
+
sizes = [(224, 224), (256, 256), (299, 299), (384, 384)]
|
| 128 |
+
formats = [("JPEG", "image/jpeg"), ("PNG", "image/png")]
|
| 129 |
+
|
| 130 |
+
records = []
|
| 131 |
+
for j in range(random.randint(5, 20)): # 5-20 images per dataset
|
| 132 |
+
img_type = random.choice(image_types)
|
| 133 |
+
size = random.choice(sizes)
|
| 134 |
+
format_info = random.choice(formats)
|
| 135 |
+
|
| 136 |
+
image = self.create_synthetic_image(size[0], size[1], img_type)
|
| 137 |
+
image_b64 = self.image_to_base64(image, format_info[0])
|
| 138 |
+
|
| 139 |
+
api_request = self.create_api_request(image_b64, format_info[1])
|
| 140 |
+
expected_response = self.create_expected_response()
|
| 141 |
+
|
| 142 |
+
record = {
|
| 143 |
+
"dataset_id": f"standard_{i:03d}",
|
| 144 |
+
"image_id": f"img_{j:03d}",
|
| 145 |
+
"image_type": img_type,
|
| 146 |
+
"image_size": f"{size[0]}x{size[1]}",
|
| 147 |
+
"format": format_info[0],
|
| 148 |
+
"media_type": format_info[1],
|
| 149 |
+
"api_request": json.dumps(api_request),
|
| 150 |
+
"expected_response": json.dumps(expected_response),
|
| 151 |
+
"test_category": "standard",
|
| 152 |
+
"difficulty": "normal"
|
| 153 |
+
}
|
| 154 |
+
records.append(record)
|
| 155 |
+
|
| 156 |
+
datasets.append({
|
| 157 |
+
"name": f"standard_test_{i:03d}",
|
| 158 |
+
"category": "standard",
|
| 159 |
+
"description": f"Standard test dataset {i+1} with {len(records)} images",
|
| 160 |
+
"records": records
|
| 161 |
+
})
|
| 162 |
+
|
| 163 |
+
return datasets
|
| 164 |
+
|
| 165 |
+
def generate_edge_case_datasets(self, count: int = 25) -> List[Dict[str, Any]]:
    """Build `count` datasets covering pathological inputs (tiny/huge images, corrupted payloads)."""
    datasets = []
    for dataset_idx in range(count):
        edge_cases = [
            {"type": "tiny", "size": (32, 32), "difficulty": "high"},
            {"type": "huge", "size": (2048, 2048), "difficulty": "high"},
            {"type": "extreme_aspect", "size": (1000, 50), "difficulty": "medium"},
            {"type": "single_pixel", "size": (1, 1), "difficulty": "extreme"},
            {"type": "corrupted_base64", "size": (224, 224), "difficulty": "extreme"},
        ]

        records = []
        for case_idx, case in enumerate(edge_cases):
            width, height = case["size"]
            if case["type"] == "corrupted_base64":
                # Truncate valid base64 and splice in garbage so the service
                # is forced to reject the payload.
                valid_b64 = self.image_to_base64(
                    self.create_synthetic_image(224, 224, "random"), "JPEG"
                )
                api_request = self.create_api_request(valid_b64[:-20] + "CORRUPTED_DATA")
                expected_response = {
                    "error": "Invalid image data",
                    "status": "failed",
                }
            else:
                img = self.create_synthetic_image(width, height, "random")
                api_request = self.create_api_request(
                    self.image_to_base64(img, "PNG"), "image/png"
                )
                expected_response = self.create_expected_response()

            records.append({
                "dataset_id": f"edge_{dataset_idx:03d}",
                "image_id": f"edge_{case_idx:03d}",
                "image_type": case["type"],
                "image_size": f"{width}x{height}",
                "format": "PNG",
                "media_type": "image/png",
                "api_request": json.dumps(api_request),
                "expected_response": json.dumps(expected_response),
                "test_category": "edge_case",
                "difficulty": case["difficulty"],
            })

        datasets.append({
            "name": f"edge_case_{dataset_idx:03d}",
            "category": "edge_case",
            "description": f"Edge case dataset {dataset_idx + 1} with challenging scenarios",
            "records": records,
        })

    return datasets
|
| 219 |
+
|
| 220 |
+
def generate_performance_datasets(self, count: int = 25) -> List[Dict[str, Any]]:
    """Build `count` latency-benchmark datasets, each with a randomly chosen batch size."""
    batch_choices = [1, 5, 10, 25, 50, 100]

    datasets = []
    for dataset_idx in range(count):
        batch_size = random.choice(batch_choices)

        records = []
        for item_idx in range(batch_size):
            encoded = self.image_to_base64(
                self.create_synthetic_image(224, 224, "random"), "JPEG"
            )
            records.append({
                "dataset_id": f"perf_{dataset_idx:03d}",
                "image_id": f"batch_{item_idx:03d}",
                "image_type": "performance_test",
                "image_size": "224x224",
                "format": "JPEG",
                "media_type": "image/jpeg",
                "api_request": json.dumps(self.create_api_request(encoded)),
                "expected_response": json.dumps(self.create_expected_response()),
                "test_category": "performance",
                "difficulty": "normal",
                "batch_size": batch_size,
                # Latency budget scales linearly: 100 ms per image in the batch.
                "expected_max_latency_ms": batch_size * 100,
            })

        datasets.append({
            "name": f"performance_test_{dataset_idx:03d}",
            "category": "performance",
            "description": f"Performance dataset {dataset_idx + 1} with batch size {batch_size}",
            "records": records,
        })

    return datasets
|
| 259 |
+
|
| 260 |
+
def generate_model_comparison_datasets(self, count: int = 25) -> List[Dict[str, Any]]:
    """Build `count` datasets that replay the same 10 images across several model ids."""
    model_types = [
        "microsoft/resnet-18", "microsoft/resnet-50", "google/vit-base-patch16-224",
        "facebook/convnext-tiny-224", "microsoft/swin-tiny-patch4-window7-224",
    ]

    datasets = []
    for dataset_idx in range(count):
        # Encode 10 base images once; every model sees identical inputs so
        # their predictions can be compared like-for-like.
        base_images = [
            self.image_to_base64(self.create_synthetic_image(224, 224, "geometric"), "JPEG")
            for _ in range(10)
        ]

        records = []
        for model_idx, model in enumerate(model_types):
            for image_idx, encoded in enumerate(base_images):
                records.append({
                    "dataset_id": f"comparison_{dataset_idx:03d}",
                    "image_id": f"img_{image_idx:03d}_model_{model_idx}",
                    "image_type": "comparison_base",
                    "image_size": "224x224",
                    "format": "JPEG",
                    "media_type": "image/jpeg",
                    "api_request": json.dumps(self.create_api_request(encoded)),
                    "expected_response": json.dumps(self.create_expected_response(model)),
                    "test_category": "model_comparison",
                    "difficulty": "normal",
                    "model_type": model,
                    "comparison_group": image_idx,
                })

        datasets.append({
            "name": f"model_comparison_{dataset_idx:03d}",
            "category": "model_comparison",
            "description": f"Model comparison dataset {dataset_idx + 1} testing {len(model_types)} models",
            "records": records,
        })

    return datasets
|
| 306 |
+
|
| 307 |
+
def save_dataset_to_parquet(self, dataset: Dict[str, Any]):
    """Persist one dataset as a Parquet file and write its metadata JSON alongside."""
    records = dataset["records"]

    # Required columns are present in every record; optional ones fall back
    # to defaults so all categories share a single schema.
    required = [
        "dataset_id", "image_id", "image_type", "image_size", "format",
        "media_type", "api_request", "expected_response", "test_category",
        "difficulty",
    ]
    columns = {name: [r[name] for r in records] for name in required}
    columns["batch_size"] = [r.get("batch_size", 1) for r in records]
    columns["expected_max_latency_ms"] = [r.get("expected_max_latency_ms", 1000) for r in records]
    columns["model_type"] = [r.get("model_type", "microsoft/resnet-18") for r in records]
    columns["comparison_group"] = [r.get("comparison_group", 0) for r in records]

    table = pa.table(columns)
    output_path = self.output_dir / f"{dataset['name']}.parquet"
    pq.write_table(table, output_path)

    # Sibling metadata file describing what was just written.
    metadata = {
        "name": dataset["name"],
        "category": dataset["category"],
        "description": dataset["description"],
        "record_count": len(records),
        "file_size_mb": round(output_path.stat().st_size / (1024 * 1024), 2),
        "schema": [field.name for field in table.schema],
    }
    metadata_path = self.output_dir / f"{dataset['name']}_metadata.json"
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)
|
| 346 |
+
|
| 347 |
+
def generate_all_datasets(self):
    """Generate and persist all 100 datasets (25 per category), then the summary."""
    print(" Starting dataset generation...")

    # Table-driven pipeline: banner message + builder per category.
    stages = [
        ("📊 Generating standard test datasets (25)...", self.generate_standard_datasets),
        ("⚡ Generating edge case datasets (25)...", self.generate_edge_case_datasets),
        ("🏁 Generating performance datasets (25)...", self.generate_performance_datasets),
        ("🔄 Generating model comparison datasets (25)...", self.generate_model_comparison_datasets),
    ]
    for banner, build in stages:
        print(banner)
        for dataset in build(25):
            self.save_dataset_to_parquet(dataset)

    print(f"✅ Generated 100 datasets in {self.output_dir}/")

    self.generate_summary()
|
| 374 |
+
|
| 375 |
+
def generate_summary(self):
    """Write datasets_summary.json aggregating every *_metadata.json on disk."""
    summary = {
        "total_datasets": 100,
        "categories": {
            "standard": 25,
            "edge_case": 25,
            "performance": 25,
            "model_comparison": 25,
        },
        "dataset_info": [],
        "usage_instructions": {
            "loading": "Use pyarrow.parquet.read_table('dataset.parquet')",
            "testing": "Run python scripts/test_datasets.py",
            "api_endpoint": "POST /predict/resnet",
            "request_format": "See api_request column in datasets",
        },
    }

    # Fold each dataset's previously written metadata file into the summary.
    for parquet_file in self.output_dir.glob("*.parquet"):
        metadata_file = self.output_dir / f"{parquet_file.stem}_metadata.json"
        if metadata_file.exists():
            with open(metadata_file, 'r') as f:
                summary["dataset_info"].append(json.load(f))

    summary_path = self.output_dir / "datasets_summary.json"
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)

    print(f"📋 Summary saved to {summary_path}")
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
if __name__ == "__main__":
    # Script entry point: build every dataset category plus the summary file.
    TestDatasetGenerator().generate_all_datasets()
|
scripts/model_download.bash
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Pre-download the ResNet-18 checkpoint into a local models/ tree so the
# service (and the Docker build) can run without network access.
mkdir -p "models/microsoft/resnet-18"

python - <<'PY'
from huggingface_hub import snapshot_download

# NOTE: local_dir_use_symlinks is deprecated in current huggingface_hub
# releases; local_dir downloads are plain file copies by default, which is
# what containers need.
snapshot_download(
    repo_id="microsoft/resnet-18",
    local_dir="models/microsoft/resnet-18",
)
PY
|
scripts/test_datasets.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Dataset Tester for ML Inference Service
|
| 4 |
+
|
| 5 |
+
Tests the generated PyArrow datasets against the running ML inference service.
|
| 6 |
+
Validates API requests/responses and measures performance metrics.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
import time
|
| 11 |
+
import asyncio
|
| 12 |
+
import statistics
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Dict, List, Any, Optional
|
| 15 |
+
import argparse
|
| 16 |
+
|
| 17 |
+
import pyarrow.parquet as pq
|
| 18 |
+
import requests
|
| 19 |
+
import pandas as pd
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class DatasetTester:
    """Replays generated Parquet test datasets against the running inference
    API and aggregates success-rate and latency statistics.

    Fixes over the previous revision:
    - latencies are measured with the monotonic ``time.perf_counter`` instead
      of wall-clock ``time.time`` (immune to clock adjustments);
    - a 200 response with a non-JSON body no longer raises out of
      ``send_prediction_request`` and aborts the whole run;
    - ``print_summary`` checks ``is not None`` so a 0 ms average latency is
      still reported.
    """

    def __init__(self, base_url: str = "http://127.0.0.1:8000", datasets_dir: str = "test_datasets"):
        self.base_url = base_url.rstrip('/')
        self.datasets_dir = Path(datasets_dir)
        self.endpoint = f"{self.base_url}/predict"
        self.results = []

    def load_dataset(self, dataset_path: Path) -> pd.DataFrame:
        """Load one Parquet dataset into a pandas DataFrame."""
        table = pq.read_table(dataset_path)
        return table.to_pandas()

    def test_api_connection(self) -> bool:
        """Return True if the service answers on its /docs page."""
        try:
            response = requests.get(f"{self.base_url}/docs", timeout=5)
            return response.status_code == 200
        except requests.RequestException:
            return False

    def send_prediction_request(self, api_request_json: str) -> Dict[str, Any]:
        """Send one stored request payload to the API and time the round trip.

        Returns a dict with: success flag, HTTP status code, parsed JSON
        response (raw text on non-200 or non-JSON bodies), latency in ms,
        and an error string (None on success).
        """
        try:
            request_data = json.loads(api_request_json)
            # perf_counter is monotonic; time.time() could produce negative
            # or skewed latencies across clock adjustments.
            start_time = time.perf_counter()

            response = requests.post(
                self.endpoint,
                json=request_data,
                headers={"Content-Type": "application/json"},
                timeout=30
            )

            latency_ms = (time.perf_counter() - start_time) * 1000

            if response.status_code == 200:
                try:
                    body = response.json()
                except ValueError:
                    # A 200 with a non-JSON body is a service bug: report it
                    # as a failed sample rather than crashing the test run.
                    return {
                        "success": False,
                        "status_code": response.status_code,
                        "response": response.text,
                        "latency_ms": round(latency_ms, 2),
                        "error": "Non-JSON body in 200 response"
                    }
            else:
                body = response.text

            return {
                "success": response.status_code == 200,
                "status_code": response.status_code,
                "response": body,
                "latency_ms": round(latency_ms, 2),
                "error": None
            }

        except requests.RequestException as e:
            return {
                "success": False,
                "status_code": None,
                "response": None,
                "latency_ms": None,
                "error": str(e)
            }
        except json.JSONDecodeError as e:
            return {
                "success": False,
                "status_code": None,
                "response": None,
                "latency_ms": None,
                "error": f"JSON decode error: {str(e)}"
            }

    def validate_response(self, actual_response: Dict[str, Any],
                          expected_response_json: str) -> Dict[str, Any]:
        """Validate the API response structure against the stored expectation.

        Returns {"structure_valid": bool, "field_errors": [str, ...]}.
        """
        try:
            expected = json.loads(expected_response_json)

            validation = {
                "structure_valid": True,
                "field_errors": []
            }

            # Required top-level fields.
            required_fields = ["prediction", "confidence", "predicted_label", "model", "mediaType"]
            for field in required_fields:
                if field not in actual_response:
                    validation["structure_valid"] = False
                    validation["field_errors"].append(f"Missing field: {field}")

            # Field-type and range checks.
            if "confidence" in actual_response:
                if not isinstance(actual_response["confidence"], (int, float)):
                    validation["field_errors"].append("confidence must be numeric")
                elif not (0 <= actual_response["confidence"] <= 1):
                    validation["field_errors"].append("confidence must be between 0 and 1")

            if "predicted_label" in actual_response:
                if not isinstance(actual_response["predicted_label"], int):
                    validation["field_errors"].append("predicted_label must be integer")

            return validation

        except json.JSONDecodeError:
            return {
                "structure_valid": False,
                "field_errors": ["Invalid expected response JSON"]
            }

    def test_dataset(self, dataset_path: Path, max_samples: Optional[int] = None) -> Dict[str, Any]:
        """Run every sample of one dataset; return per-dataset statistics."""
        print(f"📊 Testing dataset: {dataset_path.name}")

        try:
            df = self.load_dataset(dataset_path)
            if max_samples:
                df = df.head(max_samples)

            results = {
                "dataset_name": dataset_path.stem,
                "total_samples": len(df),
                "tested_samples": 0,
                "successful_requests": 0,
                "failed_requests": 0,
                "validation_errors": 0,
                "latencies_ms": [],
                "errors": [],
                "category": df['test_category'].iloc[0] if not df.empty else "unknown"
            }

            for idx, row in df.iterrows():
                print(f"  Testing sample {idx + 1}/{len(df)}", end="\r")

                api_result = self.send_prediction_request(row['api_request'])
                results["tested_samples"] += 1

                if api_result["success"]:
                    results["successful_requests"] += 1
                    results["latencies_ms"].append(api_result["latency_ms"])

                    # Structural validation of the successful response.
                    validation = self.validate_response(
                        api_result["response"],
                        row['expected_response']
                    )
                    if not validation["structure_valid"]:
                        results["validation_errors"] += 1
                        results["errors"].append({
                            "sample_id": row['image_id'],
                            "type": "validation_error",
                            "details": validation["field_errors"]
                        })
                else:
                    results["failed_requests"] += 1
                    results["errors"].append({
                        "sample_id": row['image_id'],
                        "type": "request_failed",
                        "status_code": api_result["status_code"],
                        "error": api_result["error"]
                    })

            # Latency statistics over successful requests only.
            if results["latencies_ms"]:
                lat = results["latencies_ms"]
                results["avg_latency_ms"] = round(statistics.mean(lat), 2)
                results["min_latency_ms"] = round(min(lat), 2)
                results["max_latency_ms"] = round(max(lat), 2)
                results["median_latency_ms"] = round(statistics.median(lat), 2)
            else:
                results.update({
                    "avg_latency_ms": None,
                    "min_latency_ms": None,
                    "max_latency_ms": None,
                    "median_latency_ms": None
                })

            results["success_rate"] = round(
                results["successful_requests"] / results["tested_samples"] * 100, 2
            ) if results["tested_samples"] > 0 else 0

            print(f"\n  ✅ Completed: {results['success_rate']}% success rate")
            return results

        except Exception as e:
            print(f"\n  ❌ Failed to test dataset: {str(e)}")
            return {
                "dataset_name": dataset_path.stem,
                "error": str(e),
                "success_rate": 0
            }

    def test_all_datasets(self, max_samples_per_dataset: Optional[int] = None,
                          category_filter: Optional[str] = None) -> Dict[str, Any]:
        """Test every dataset in ``datasets_dir`` (optionally filtered by category)."""
        if not self.test_api_connection():
            print("❌ API is not accessible. Please start the service first:")
            print("   uvicorn main:app --reload")
            return {"error": "API not accessible"}

        print(f" Starting dataset testing against {self.endpoint}")

        parquet_files = list(self.datasets_dir.glob("*.parquet"))
        if not parquet_files:
            print(f"❌ No datasets found in {self.datasets_dir}")
            return {"error": "No datasets found"}

        if category_filter:
            parquet_files = [f for f in parquet_files if category_filter in f.name]

        print(f" Found {len(parquet_files)} datasets to test")

        all_results = []
        start_time = time.perf_counter()

        for dataset_file in parquet_files:
            result = self.test_dataset(dataset_file, max_samples_per_dataset)
            all_results.append(result)

        total_time = time.perf_counter() - start_time

        summary = self.generate_summary(all_results, total_time)

        self.save_results(summary, all_results)

        return summary

    def generate_summary(self, results: List[Dict[str, Any]], total_time: float) -> Dict[str, Any]:
        """Aggregate per-dataset results into one summary dict."""
        successful_datasets = [r for r in results if r.get("success_rate", 0) > 0]
        failed_datasets = [r for r in results if r.get("error") or r.get("success_rate", 0) == 0]

        total_samples = sum(r.get("tested_samples", 0) for r in results)
        total_successful = sum(r.get("successful_requests", 0) for r in results)
        total_failed = sum(r.get("failed_requests", 0) for r in results)

        all_latencies = []
        for r in results:
            all_latencies.extend(r.get("latencies_ms", []))

        summary = {
            "test_summary": {
                "total_datasets": len(results),
                "successful_datasets": len(successful_datasets),
                "failed_datasets": len(failed_datasets),
                "total_samples_tested": total_samples,
                "total_successful_requests": total_successful,
                "total_failed_requests": total_failed,
                "overall_success_rate": round(
                    total_successful / total_samples * 100, 2
                ) if total_samples > 0 else 0,
                "total_test_time_seconds": round(total_time, 2)
            },
            "performance_metrics": {
                "avg_latency_ms": round(statistics.mean(all_latencies), 2) if all_latencies else None,
                "median_latency_ms": round(statistics.median(all_latencies), 2) if all_latencies else None,
                "min_latency_ms": round(min(all_latencies), 2) if all_latencies else None,
                "max_latency_ms": round(max(all_latencies), 2) if all_latencies else None,
                "requests_per_second": round(
                    total_successful / total_time, 2
                ) if total_time > 0 else 0
            },
            "category_breakdown": {},
            "failed_datasets": [r["dataset_name"] for r in failed_datasets]
        }

        # Per-category average success rates.
        categories = {}
        for result in results:
            category = result.get("category", "unknown")
            if category not in categories:
                categories[category] = {
                    "count": 0,
                    "success_rates": [],
                    "avg_success_rate": 0
                }
            categories[category]["count"] += 1
            categories[category]["success_rates"].append(result.get("success_rate", 0))

        for category, data in categories.items():
            data["avg_success_rate"] = round(
                statistics.mean(data["success_rates"]), 2
            ) if data["success_rates"] else 0

        summary["category_breakdown"] = categories

        return summary

    def save_results(self, summary: Dict[str, Any], detailed_results: List[Dict[str, Any]]):
        """Write the summary and detailed results to timestamped JSON files."""
        results_dir = Path("test_results")
        results_dir.mkdir(exist_ok=True)

        timestamp = int(time.time())

        summary_path = results_dir / f"test_summary_{timestamp}.json"
        with open(summary_path, 'w') as f:
            json.dump(summary, f, indent=2)

        detailed_path = results_dir / f"test_detailed_{timestamp}.json"
        with open(detailed_path, 'w') as f:
            json.dump(detailed_results, f, indent=2)

        print(f" Results saved:")
        print(f"   Summary: {summary_path}")
        print(f"   Details: {detailed_path}")

    def print_summary(self, summary: Dict[str, Any]):
        """Print the test summary to the console."""
        print("\n" + "="*60)
        print("🏁 DATASET TESTING SUMMARY")
        print("="*60)

        ts = summary["test_summary"]
        print(f"Datasets tested: {ts['total_datasets']}")
        print(f"Successful datasets: {ts['successful_datasets']}")
        print(f"Failed datasets: {ts['failed_datasets']}")
        print(f"Total samples: {ts['total_samples_tested']}")
        print(f"Overall success rate: {ts['overall_success_rate']}%")
        print(f"Test duration: {ts['total_test_time_seconds']}s")

        pm = summary["performance_metrics"]
        # `is not None`: a 0 ms average (truthy-falsy) must still be reported.
        if pm["avg_latency_ms"] is not None:
            print(f"\nPerformance:")
            print(f"  Avg latency: {pm['avg_latency_ms']}ms")
            print(f"  Median latency: {pm['median_latency_ms']}ms")
            print(f"  Min latency: {pm['min_latency_ms']}ms")
            print(f"  Max latency: {pm['max_latency_ms']}ms")
            print(f"  Requests/sec: {pm['requests_per_second']}")

        print(f"\nCategory breakdown:")
        for category, data in summary["category_breakdown"].items():
            print(f"  {category}: {data['count']} datasets, {data['avg_success_rate']}% avg success")

        if summary["failed_datasets"]:
            print(f"\nFailed datasets: {', '.join(summary['failed_datasets'])}")
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def main():
    """CLI entry point: parse options, run all datasets, print a verdict."""
    parser = argparse.ArgumentParser(description="Test PyArrow datasets against ML inference service")
    parser.add_argument("--base-url", default="http://127.0.0.1:8000", help="Base URL of the API")
    parser.add_argument("--datasets-dir", default="test_datasets", help="Directory containing datasets")
    parser.add_argument("--max-samples", type=int, help="Max samples per dataset to test")
    parser.add_argument("--category", help="Filter datasets by category (standard, edge_case, performance, model_comparison)")
    parser.add_argument("--quick", action="store_true", help="Quick test with max 5 samples per dataset")
    args = parser.parse_args()

    tester = DatasetTester(args.base_url, args.datasets_dir)

    # --quick overrides any explicit --max-samples value.
    sample_cap = 5 if args.quick else args.max_samples

    results = tester.test_all_datasets(sample_cap, args.category)
    if "error" in results:
        return

    tester.print_summary(results)

    rate = results["test_summary"]["overall_success_rate"]
    if rate > 90:
        print("\n🎉 Excellent! API is working great with the datasets!")
    elif rate > 70:
        print("\n👍 Good! API works well, minor issues detected.")
    else:
        print("\n⚠️ Warning: Several issues detected. Check the detailed results.")


if __name__ == "__main__":
    main()
|
submit.sh
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
export DYFF_API_TOKEN=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJqdGkiOiJiODBkZGY4NmUyNzk0MzI3OTNkOTlmNTk2MGE1MWRlZSIsInN1YiI6IkFjY291bnQvN2IxZWZiOWMwMDMyNGNkYzlhNDEyY2M4YmQ4ZjcwZTIiLCJuYmYiOjE3NjUzMTg1NjMsImV4cCI6MTc5Njg1NDU2MywiaWF0IjoxNzY1MzE4NTYzLCJncmFudHMiOlt7InJlc291cmNlcyI6WyIqIl0sImZ1bmN0aW9ucyI6WyJjb25zdW1lIiwiZGF0YSIsImdldCIsInF1ZXJ5Il0sImFjY291bnRzIjpbInB1YmxpYyJdLCJlbnRpdGllcyI6W119LHsicmVzb3VyY2VzIjpbImNoYWxsZW5nZXMiXSwiZnVuY3Rpb25zIjpbImdldCIsInF1ZXJ5Il0sImFjY291bnRzIjpbXSwiZW50aXRpZXMiOlsiZGM1MDlhOGM3NzFiNDkyYjkwYzQzMDEyZmRlOWEwNGYiXX0seyJyZXNvdXJjZXMiOlsidGVhbXMiXSwiZnVuY3Rpb25zIjpbIioiXSwiYWNjb3VudHMiOltdLCJlbnRpdGllcyI6WyIzYWUzMzdjMTY2MGQ0YjczOGM0NDNjMmMzN2ZjMWJjYSJdfSx7InJlc291cmNlcyI6WyJhcnRpZmFjdHMiLCJpbmZlcmVuY2VzZXJ2aWNlcyIsIm1vZGVscyIsInN1Ym1pc3Npb25zIl0sImZ1bmN0aW9ucyI6WyIqIl0sImFjY291bnRzIjpbIjdiMWVmYjljMDAzMjRjZGM5YTQxMmNjOGJkOGY3MGUyIl0sImVudGl0aWVzIjpbXX0seyJyZXNvdXJjZXMiOlsiKiJdLCJmdW5jdGlvbnMiOlsiY29uc3VtZSIsImRhdGEiLCJnZXQiLCJxdWVyeSJdLCJhY2NvdW50cyI6WyI3YjFlZmI5YzAwMzI0Y2RjOWE0MTJjYzhiZDhmNzBlMiJdLCJlbnRpdGllcyI6W119XSwic2VjcmV0IjoiM3p3MjFHUFI0WGRpenpIVTdnT0ZBdyJ9.bR71zVzqrRRsK_FZCYMGmTYCsQekz6LfmzM-EoyBOW0
|
| 2 |
+
venv/bin/python challenge-cli.py submit --team 3ae337c1660d4b738c443c2c37fc1bca --account Account/7b1efb9c00324cdc9a412cc8bd8f70e2 --task 567316f19d37490b97aa647e7017ef44 --service sida
|
test_main.http
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Test Prediction Endpoint
|
| 2 |
+
# Works with any model configured at startup (default: ResNet-18)
|
| 3 |
+
|
| 4 |
+
POST http://127.0.0.1:8000/predict
|
| 5 |
+
Content-Type: application/json
|
| 6 |
+
|
| 7 |
+
{
|
| 8 |
+
"image": {
|
| 9 |
+
"mediaType": "image/jpeg",
|
| 10 |
+
"data": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wCEAAkGBxMTEhUTExIWFhUVFhoYGBgYGBgYGBgXGBUYFxUYGBcYHSggGBolGxUXITEhJSorLi4uFx8zODMtNygtLisBCgoKDg0OGxAQGy0mHyUtLS0tMjAtLS8tLS0tLy0tLS01LS8vLS0tLS0tLS0tLS0tLS0tLS0tLS0tLS0tLS0tLf/AABEIAKoBKQMBIgACEQEDEQH/xAAbAAACAwEBAQAAAAAAAAAAAAAEBQIDBgEAB//EAEMQAAECAwYDBQUGBAUDBQAAAAECEQADIQQFEjFBUWFxkRMigaGxMlLB0fAGFBVCkuFigqLxI1Nyk9IHFkMzsrPC4v/EABkBAAMBAQEAAAAAAAAAAAAAAAECAwQABf/EAC4RAAIBAwMEAAUEAwEBAAAAAAABAgMRIRIxQQQTIlEUYXGRsTJCofAjgdHhBf/aAAwDAQACEQMRAD8Aw5s1A0U9kNocCynNNQ0DLs6tmEC63NOiSlpYJ2HCPJlNpDGXZ2rBgsQUlwKjOJdzNjQ6D03QtkyH08oum2cJGQi8yCPDrE2ejRotg89zd7oEkqAIoOkNhaAQE4RXWkLTZ1Pk4i2y5784TQi3flawzTYhsOkXy7MNh0gizFwKNBQlxS5ltkETZhsOgi5NmGw6CCBLbWCESHaBc7SwEWQbDpHZlhSQRQcWhl2MCrtKUKGJgPFx5QssrcvStGWUmI1XEp3IcHhExcWQYeNI1UmcFB0mhyjxDcYwOjWk9z2Y9bQgtjMKu6UipFdqExXaJiQHCR4t8Ic2iypxYi/wim1XbiY67fOO+GUWtTuxo9c6ibirL2ZS3SysvhFdhACrEraPoCLAQ1AeAELrXdgUXAz+sodudNaUkCn2K0nJt/7MYLPwiwWfhGgXdvARxFgHDpGqnGTyzLWnTV1ERfd+HpEhZeHpD37kNh0ixNjGw6RrUTyZVBALJw9IvRd/DzEPkWQbDpF6LKNh0htiblcRy7u/h80xb+Hjb0h6iQNh0iX3bl0g6iTiIfuHD0jyrAPd9Ieqs/LpEV2fl0jtQNJnV2Lh6RUbHXL0jQLkDh0ildmrp0g3BpEa7INvSK1WYbekPV2Xl0ildl5dI64bCc2UbekV/dht6Q6VZqftEBZuHlHINhKqzDaBewG0aM2T6aAfu300cwJASJymppB9ktYyUIVS1NR4mCTlGNHozWTUWaUkhwxHCDJNnABDZ8PLlGPkWlSS4Lcobyr6mEhyB4CFlSTGhXqQ5uhjOlJByPygKcEDIkHQ/N4eyLRKXXEAfntHLTIkqVgWCDnl8YKbW5F2bwI0CoCWL1NKdXhpJugZ5P6xcLtTKLoDvXEch4HOI2K/CtfZplYqs4LPxY5R177AsRstjmhR71DkaMPAwwtM5EpJMw5DqWp1hhMUhIYqCSqgBLExmL+uObNI7N1M7qJoBpXrAGVjky/5KsgrPLgc+BiVqtgWlAQCSMhw/M+1YrsH2KmYhjUGIdx5jnGl+zlx9kCpQ7xpvT94LaCUXTJUZbkkYXehd826bQDOXMWzoolRAUB7QNXz+Gkay19nKQ61JQjV9eAGpppFCLxshT/6yGpQltaUPEQEwGYu+7FKmv2ndDsRlUsQDGuFkYeEQVZbNZpa7QAMOalJ7z4jn56QnsP2yQcRmIIr3QPdrmSz6c+Edl7HYG8y7Bm3ebWK50pMtJVMUBuTlySIJkX5JKEqMxCSfyk1o/TLWMFfX2hmTz3gAgZJGQrmXzPGDGOQuTasaFd+2erEnCMsiasyd4BvO14UCYO6lTcSX32yeMcu0h8oqm2t+W0NpQY6th9OvNSi9AOIL+sdl20anyMZ2Vajx6xaZyuPUw6kkh+1J8D025O/kYnJtD7dDCIKL6+cMLGo7mGuuAdvhodygTF7GF8qcRv1ghNp4HTWDcjKkHS0RahML/vJaj668YNlzOefwjrk3AsVLMeMqJ9rTKOLUdjHC2B
JkqIGRWClJOxiKJKn16GGTFaKOxilciGIknj0Me+7E6HpHXOsKjIeJCzfTw6s93l2LjiU5UgwWRCM1IOlQB6mElWhDdl6fS1KmyEKbrUfy67iFv4cv3fMfOH96XskA4UgUB7qwPzcBGc/Fuf6/wBozrqZS2R6a/8AmU0vJu/9+Rip1oxF2YxZKKgHYtvEJNNQecNJgRMQEpUArV6B9IDnosrYJxoKtd6s+vYImYOESC6wFNQpJYxwTYsmmYpQknZmgu219mXfw02qIMt17qUkIQmvvByrOgB0jNSpu5g2TawmteBByjrISzDECaokl6ZkuW59YayLaZSQqWAfeWdTqG1MJrbfJmJSmjJq7VLgO/SAjaTA3Gs0M7dbitZWrM+HlBtzXsqScSTXIg1BHEPWM+pZOcSwGGxawLPc+m3X9p5cxSUKwSw3tElnbdxhru8aeStKgClSSDkQQX3bePiMte8N5N9qlpTgLKANSKir0MSlT9Dxb5Dft7eOO04KYZQwgjUlioHkaeBjM9qSWakSt9sVNWVqqo5nc5aQPKVWHUtKsOqTkzUSr3TLsi7KHV2jKUrEQyu6WAaoZLHfwqgM1uMUrcxxElR0JibqpGiHSSfBbOtyiGyEU9sTSLU2RSvymCLFZ5aS6yS2QTqeY0iU+oSWDVT6J3ysFRsC2ds46LuUdINVe89+7LASKDuOQNKkRfOtE7D3+7i91OXAloxSrVebHo06FFbJipFiINRBiLKT/eJ2CzlRNSaah/jDuXYhuOkOq9lkWVBXwKE2SuXnBdnsphqixp3T0EXJsoH5k+Xzh41pSZJ04RQDKs5iZlQTMkbK6f3ir7qr3jmNDvzjTGczLKNNlQkq4Qxs0vff4RXKs5bM5H1hnIsZfPXbhFYzMdeMUsFfYCLAkQbLsf00W/cPrDDazHouL0JEXolwaiyAbdGiJnoScknxGsTl1MYlqfSTnsQRZNx5x6b2aNxVjQnR4WW+9yMnyOS29BCW1XiVHNQqD/6hOjRDu1Kn6cI9On0NOC8x/bb1GE10907xnbdeZJoR7Xu8IEnTFHVXUwJMB458Y6nSSeTa3pXieWomhbKF/ZQahJOp84F7NW584u4+iDqW3YjTKpoYjMs6hXfjB6btWQCx5/3ERN3L3HWB3EjzFQk9mAAGIrlQWuxrBZuhEVrlqGhju4hn08wTLSJpMWqA1cHqIj2UdrO7ByJBcR7OJYYOsXsssTOIiQmmKkiJEQXMMaBaFcYJRZphySS8UyJpTklJ5gxbJmzS+ENi2f5tEJ1JcG6l09Pm7+gQq5ptGzOlAY6izS5ZZQmLV7oSBVst4gi7p6iF1cUCipj4EmGdiuxQqucl3qy3roMqHkYyTrNLyl9jZToxviFimRZEGvZTArRD18SoU84cSrKEh3QkalQdjqyiO8fDSIi0KlDNCf8AUsk8SzH1iKbxQ7hBJOaiXHIA6dOUY5dyptsa04wLxbEthBKyNQwB8AwIiuyS5jv93wbLIDB9e9r4GDEXkGGGnBLJPj3fjEzeS1ZqNN8Pq0TjTnwvuLKoTk2KYoP95UAMwC/9TR1UgJDqWFPus1/lLwDNnKUXxEeMUT0Pq5EOumb3YnczgITbUk90N/K3wEW947QHd8glRzy2fWH8mzcPKNcIQhsQqyk3lgaLOdvOCJdm4ecGosx28oJlWQ/Qi6qsyShFAMuxjbzgmXZRtqNYPlWM/SYIFk+sMPrZnlJcC+XZBtvrxg+VZztrvF8uQw66cYqtiyAQHdtOcF1VFXIqDquxYCE5wPaLySMz/SreFk4zN1/1QHaJUz+PL+LeM0qjmz0aXSQgF2m9E+9v+Uwtn20nUdIiqyrOi/0mJpsCvdV+kw0IpF5aUtwNcpKs/jFYsiB9GGX4Ws+8P5TF8q4Vn85/QfnGmM8Gacobt3FZkp+nipVkB0840ku5CKu/8n7wSi6hsP0QVNIk6vzMim7VaDzEL/uC9vMR9LTZEj8g/SPlCb7un3R
0EHvEXM+c2FUxGEpOTEO8Ml30A5nSpZIbJIdlKbP6yhR9mb8lJkNObuszt0AMJ/tPeAVaFdmRgDANkW1Fd3rGWOqU3Frbk1VZ0lTUl9jdzZkgueyALcT5GFc/CGZmJAAIZySwoDErivaTNnKRMKQOzSoKUcIBYlVSc6geEZ/7Q3oDaFJlkYEqABFRTDUF9wYWDk5abBmqUY6kzS9lKI70kPuFK9IpNyy1eyseX7kQTJnSVzxKdNZSVgghsRfECXzZojbkIRapVncPMBL7e5XckGnKFVbNvlcq6cd7823FFoukgsK/XKKVXadx6RprJYkrmzJSZgJlpSS71xO7DgwrxiJsQ7fsKYuzx5aO0N8SlgXsX5RlzYiDxi1CFJqAAd2BPnGpNzHSvIExSu7iPytzT84D6mEi8Ona5M0EKNMzwSPhBNnBBZSlJ4YRDU2Ze5+uERFmO3n+0c5pjxjpe7KJKUEe1lulPkKxcpKSw7x8UgfpaLUWfkD4GL5Vn5xJ6dyuqTB+yQP/ABjm7egiXZqOSQkDb94Yy7EOJ8IPs9l2ST1+BiMq0Yhs+RUmQTqfB/nBybtQ3t12NPM5Q0lyfeSG4GvDNUGy5LCiW8B/yNYzS6l8E5zSE9kupGalYtggYusFGyyxVsPMEK/pL9YbGS4qtXgUt5GKDYEh3HUv6RKVST3ZFVE3uLZIQSQkF9z8yTDGVLLRGXISMkjwEFo5RtoW04Fqy9FkmUX8YNlSoGlGLu1LUeNaZhmmwoIimdaEj83kflAmFfvq6n5xEWRRzV1eG1ehVTXLJqtg38jFPbklw0Eou/iOkFpsg2HSBbljOcVhC/t1cI92i9h9eMMvu42HQQDfN4S7NLC1MXmIQ1PzEP0SSekMvkS1E5Sl8IvQkwPeF5ypIlqUQ01YQDTUEvypCu134E3nJs7gJMhb/wCpRlqS+mST9GjJNitmiCYHm25KVhBLEh/MJHFyTAH2kv1EiWFBSSoqSAMQ94O7HJqPlWpEYL/qNfahaWlKDJCA4OIEM5Aagqa8oaMXIXk+qMYXXjeqZK5SVFhMKn4BKCX6gDxA1hJcX2pSbJLUopxBABGMAummR5fWUfO/t3fhnT3SaIDJdlMXeh2y8RBhTbYXg+x2S29oVhLdxQDuC7oSsHh7XlCx1RlfsF9qUgTjOVVRCwVKAySElLqNS4cAZAx7/veV7p/XHODTsdufHQmOhMNU3cWyiYu07RbWiSpsUlEdww3F2K2iYupW0LrQ/bYpSohi8XzLZMMxMzEcSQADSjUEMfwlW0d/CFbQrlG9x1GaRVdd8zZE1U1Ku8oMpw4IO7xcm95vaqnEurBhJ4FOH0iQudW0XJuhexictDyVjrWBlL+2K2kBSB/hPiId19wpS9eJeHU37YyGThlKqAVB6ineAJO7eEZhNzLP5TBAuJZ0jPKjSfBeFWquTbWW2WWYpKUzgSpRSBlUAHUcQHg+TdiFpC0EKSqoLpbyMYBNxL2MHWe6JoyKhpQkZVHmYzS6dftkaFXmbH8HGxjv4YBv0hbd8+1oThCiRhYYg7cQTWHVkvif/wCSUk8g2/PhEJUqvsPfnwVy7E2QEEy7Fw9YZ2S3IXmhSNnDv0hkiWDlEZQnyRn1UluhRIsja+vxg6VZ+MHJk8IsTL4QY0W9zJOvcDMvjFE2zgw4RLjk2VSkafg5OGtEo17MQfdhFiZEHrkxDA0dTi0rGjvXBhKixMqIzrUlOhPKF1pvZY9mWdc66UyjTGEmFKUhuhERtM9MtClrUAlIck6CMpeFstMwNiKB/CCH3rnCKdd0wu5WXDH2qh3Y71DxeNNndp8s2ls+1NklKwqnAl2LAlqO5YZct4Ro/wCoEvtS6T2dBSp7ruQMquOkZw3GrY/pMQmXGpsj+gxZU0DtpBdv/wCoK+3lTUSwyJakqSciVEEts2FPQxl74+0E+0Yu0UGVM7RgG72EIDcGEMplyK2P6DA825lNkf8AbMVUUhHH5C+9L8nTky0
TJjpll0hgGO7gOTziidfE1U9M8realmUQNEhIozZAdILXc6joofyGKlXSt/ZV+gxRRQjT9A943zNm4cZBYN7IELpk/l0hv+EK2V+gxFdyK91X+2YooxJtTFcq2KbDSvCK5q4ci5le6r/bMVrudb5K/QYZJAakJRNPCKsf00O1XKvZX+2YD/DFbK/QYNkTambtCpDJ7qMhmeHOLEqle4nqP+UZWUSwPZpelRXLh9ZRei1N/wCM/pEeG1LhnsJx5RqCuSPyJ+vGJCdJ9xPUfOMwbfoEp/mGrR4Ty1EpfavzELpl7DePCNaJ0r/LH14x42mV/lj9JMZNE5fuS+RB+cUrt6sTdjLU2oA+Md25eznKK4NmLxkDNKRzYepif4vZx+VHKh9IxybSvMJlo/lST0Ecl3wx709XhLQPWC6UhdcTdS71le6j64GLk3vK2SeQf0EYtN4YgCkzCd+78IIVaJgFZiv6Q3nEnB+/5HsnsjYpvhGYQP0EeqYtTfKRmEeX/GMEu+QAxExR5GsTk3yMlBXiTlAdOYPA3qb+Gw8A/oIuRfnL+gesfPZN8SnoS2tEv5xem2k1lomKGrhKa0Pz6wHCa5OtTexv/wDuAjUf0x0faBW48SkRh02yaR7BT+k+hETlXo2eNxsk/ExKWtc/yN2qfo3Qv4+95hvKJi+1ak/XKPn0y8VBWJMuYRqWB8qfQgk24kPgmDoPKJt1fZ3Yp+jefjR1V6wIq/Ve8esY+VbD+VJB4gPEROXWp6o+CTGno9cpNSY0enprg1y75V73nFS73Vv6RkV2guAZiknbAFcq0aIptuYxTCRniCUjoTHo9tDaaadrfg1/4wrL5RA3qdvL9ox8y8xlWmZdB8gaxGVeKQCSZhzoEHTZtIdUzv8AEvRr/wAV3w+Xyjn4oNh5RlpFtJGLs5gG6gB5Kyga0W8g1mtwYK8wqCoXwc+2lf8Av8mu/FkuzDoflEF3mk0+BjLpthNHWTuHSPi0cNqdTEKp9VJEFROtD0jSKtaYqXORxhCbakaN/NHUWknceMGzKRVO+y+w77RHHzjgXLfXzhP95zd6cYslz3rUeMdlD6aT4GRmSuPnEwqUd/OFSpvPrHEzq5keMdZh0U/SGyjK2PnER2Wx84RqtR3V+qJCeWd1fqMdZ+zlGk+Byex4+cLOzlbHziEy0NqfBULfvX+r9UMk/Ys6dL0ZIWxQyrTeLBal69IhKkJcd6jfD5xebRhYhIfjyMT5wjw0na7ZMXiRQp+ucdC5inKEkDkdn+ucULt4OaU8+r7xZZrSTQZH6Hp5wko2zYeMtWNR4KmULqqPrXlBGJRGY55Goi6RIQc1e0aas523j0+6kl8KwSOeVM+HzhZVE3kpGjJLx/ILKQoHEC/jBcu0rz7hBOvjxgKfd6gHxUA9Sz+kUJJQWJB8Y5pSyBNxxZoZfeJhOaQ21ItwrA9tLcVE604QnmzgajOIDGugVTjlw8I7tpgda2Ms0EuclJZS3JDgvwGjHcxNFsSNH4kPV4SSrsmPmH5wZOu9ftEuNSCOGY5vCypxezKQqSX7bDgTyWKVpHBgPNmip11xTta5noBxhQuZLQQS6nBOeVWgj8aQAwQw2/fwiPbfCLdyPLt/sZWeyqaigdWdiYsCVADvl6ZVH75QpTfcshsJH8Tl/XhHpiisYkKpsebFzzIhHCXI6lHdMZTLTPTkkq8QGiUtU13Vm+poOJ4QhwzCogE0+vjFk2yziD3nEF0rehVVzyP13tLTmoE/wgepzio3lJPsqUkqo4YtzDRnvw3+J/D1ia7qAFZgJ2HzMaOkpU4yeRZ1qlsRG6J05NUWhKgdHIpuQdeFYOlqnBGIS0zNyApRrpUcNIVWW85VnHcl9/UqL6vTYwDO+1NoJJxnhwrG3S5vCF70Ka8pP82+5ok2tnM1SZb/AMAfhrtwgebbpJHcnYXZ2SEqPiPCEh+0S1pwzQFjiPj9ZCOJNmWaOgmjZDqfrKD2tOWL8UpYg0/rdP8
A4OVSEKS6lTa/xAvxr9VidkkywD2ZNNVAO+wJ56CFGBKa41KDigOgj1qvJRDSkYQKd0acTvxhcvCYdcF5SWfux5PmMaHF/N8B9VgY29CQ6mfQMT8TCGVZZkwuVgEnJ6wzlXM5AKiQMz40A2o/SGUYrdgdarP9MfuEm1ylVDOdMuUeUvQM+0DruZOeJhpn8vSLkXaEd5CqjfOtI7VBcjqNV7xS+gR2E1nUgAAO7jIByWfaB1WutMJ8DHBeExL4ypSdRoUjMeIpFtmvCQst2AHgnSC9Vr2v9Du5G9k7P5/+BUjCpJ38YrmSpnsoQC31qeMBT7UUTJaUkgLOlAK69YYJtZSyiol9jX6pAeqI0ZwnvhrewFNVRlsFn2QHqPPjFaitKQWDGg+gYYXdYwpJmKUlWEsxqchk/OOTpImKKQQkJL8NqCG1q9kBwla/2BkzWFWbkYG+8J3HQxJSq5FtdoGxJ2HlDYJOvISWezlTMdhrR96cDF05JoFEUptlE7HbUppgHluD84vtlmxgYSBmaDyprnGVzle1sGVUo6bxd2BBKRo+tdo8be1EjyiKLGpiS4wjIg1+UVT7IUkBi5Zs6uzN1jrJiuU4rCsdm21SjyyrxixFrmJNaOIjJksRiHWj8oItWApISAKv5Gg+vhDtRatYVOd9Vy+TOWujUyD70o+WsGSruCmVMoQ3doHFCC7sz7NCpFrCEsKceORMQn3ip3JzGnn5wqptbFnVj+7JoBckpYSy2wkBVQp2YEUGdM4ptsrsckhSFZHNw+41YfTwkk3gt6Nrm2uecGKvHCkJUHoR7RwnUHiQ/nCSgykasLXSsyybbCr2UgNs+2bxFCphDVKX5OeeekUotYNEpDk15Elhs0XLmzBUghyQ+jUSacH84S1sI7VfLZZ+HKUKjL5hxuzGIS7rJFVMHbwicq1E0xF2y3ap8aeW9IhNvA5ca+EC8w6aW4bZ7slpHeLnUCtHz+EFpWEuE5El/HhzIhYm1qUXBAo/pX+3GCpE6gJfXhlmD4t55xKSk9y8HFYiW/iEsA/4ZDsDtw8jFsudiUCkNTfcf2iaES3cgU5bvkenXjFU+1Swe6APD639PCeHhJlVfdtWCLQZYAxHpxETVd0hdQshxqxY6NTjGVtVqCifKrx6VeC05aZMNeY5Rs6Sg75Znq9VBYawN7bcqkg4ZgUSQANcnJPllvCedY1DyLa1D5Qys98Es4Ph8R4mLkIQsE4iC54B3GVeJjQ5zg8meVKnVzTM+hwQAHPJz0hnc9nmziVAJCQQTiAD1zDhj46xcmRKSpy9D4M9c86U5R6dfobClDCgYACgNBxoWhu7qdooFKgoZqysvS5GiblkIUP8U0oRpic1yJoNOAiM+xWcuoKIO2LfLTb165mfbFhzoTvWh8opN4E6xVU78jvqqEXp0DW22lcs0A4ED64QEu2KDkEseNOXnHFXiFe3VhFlkCVqSkJcbZk5A8svCOUVHdGepUc5eEvodkWxY1LaudOUMLJayaACKvuiW9jMBjrmwPB9orl2RUouVu22r0zeFThIvDvQtyvwjTS7vnLQQlDhSWPeSKKGjmOTfs5JlygqapaV4mUAQQASWySdG11hR/3UpAwjHtRW1IHmfaRUwkKMwjOqqUECNOpf0is+o6d/N/MutFmmoUBNThSrKoJKd6E8IslTbOgsFnGMwQqgzzwttHDOXNBJKlEBkuSojgIW4CJhBQokCtC+mcPvuRk3Td4rD5ZdZDLShS0+0D3XBbxpxiSr3Qw7wxfmGFVICNkmZ99AGaWUMXFvrKHNguATEgqKUOHdSc/Es8GUoJXYlONaUrQR6xze1BlZlRyFMmVmaD2Yj+Bq9w/qT84Z3YuTJFUoUtyQruggEAMCa79Y7+Mo2H60xmdSSfisHofDwkk5tXMMZRSyq5Aqpk5I+n1eG9jVTF7oxB+B2gWRNbulRKSzk1O9a11iAnglTUchqlhvU6ROaMFJ6M+wiTOUrEFOS702Bb5
w/tEhBUkqBKqTKHYJc58ATzfWF1lnABSZgTLIYggOXBdidDTxjiLwQoDchjt7Ayfd68hEZXbujVCyXk7iu87PMM3DTvKITsyS2goKQBOCgSkioLEEEeuUav76hRlBVGck8SAH5sBrp4Qjt6DjChValFRSBkfaSATU00rpWtLUqnBkr0UvJPkWoRiqSenX4dYNRZsSHb2AScy44cHIHMvwjsyUQvHgo7qGgr3gT4n9od2BEsDCafkUl83Se9UuAVFJ4FNKUg1KllcWlSu7CuVYFLQlgAaCgFQaBXGjvqOzVR6RdZrmUtQQ7AYiXYNhzetTkW4HaGlltQT7RGVDscTKqfHziSbWgFIo63LijBTBuqTlnEe5LNkalQji7BLHYhhwLSwYEkb0BPIpy3Lcoum5AEvVzpnQmn+kRG8pxKxhd1JGvCtcsh4Qmnz1FVSK6U+HOFScsjNxhgYlA2qfKrv5esD2yUCXAOp6s7vxc8o7ZrUAxI8csjWvjHJtrBDDQN4vsKcPCOV0xZOLRTMThLYvDXx216cY9KtZq78urvxf4xZZ5Yd3ajeH03SLLRJDLpn5OXo2XtK6jaHvF4FSksoGFrKqO1R1Km+MCTJpdnrBapTN3agnU+Bo20CTZRBA+q/Q6Q6UeCMnPkqJjqpxjix4/WvSPJlkkAaxoo7mare2Cci2KSXBZtdoMTeaiGJfy0aIi7y4QWL1JTWgr3S3u8IlarsSlScKlFL1cVAfccI0OzWToqrDY7KtAOYcnKpJenX94MRdoUArIsOAHmXf45RXZ7JhwlOpJdy9DRxpl/UXjgvc6ihJ3bfXPPzjJJO/gaYuKX+UBn2JYJdyBmQ5GcVCzFRPdbNtuXR4dIvVLEKYuK8/r1ipFqST3aZM5bgz+MOqk/QjoU2/1AlnulZFEkmn9SgB5ehhhd13TETC7AHCHerLq4fNtYKss8JV3gcScLDINuSztUnMZcostVuSSWZ3/YGgzbLgIDrTl42Lw6SnHyvsUXhZhgxOcRS+FxzAIGbJCiSNtc4RzZ60ird4AimjlvSDbVaaqc1Hu7voYoBSUlzWgFBo3TKKU04rJOtPVK0GL0uS+p+jDSw3aVEFSS3Q8Dyz6RRZZIWoUYFQDvUOWJ40rpGtu6XLSgSu0ol++SHLknydvCDVqaVgPSdL3H5bC24LFaSsGVgYLq5rwjXSrLJCiZ4Pbkd/CVYdGZqeyEwDfloTKCClQ9kksWybMiMlb78UpRIL5fnJ0EQSnVyjdKVLpvFu5qL6mCWwmUUoOnWj6tFUgTbQgIl4TgD1pTKI3TasSSpYxYTmouwYanIRGdakhalJUzn8pb00gW4NFv33w/uZ+8rCsTRQNgcsa0xQD2Y2PnGvvK9pakKUMGP2QAatu4HE04Rlfvi/8xX14xpgnJGCtCNOWHe/99i3EpgSaEZcK6RYhWEeNcsiND4QPNNTz+EMrMHUp69xH/sTEbXdjCgNdoJNSTRn2bLPlEEWgt0PL6YdIEUYOQkfdSpqicA+rYCWfZ4GlIXW2FWN1lKUkBRoHoK5V0OUXSAQtLkgjvAlmYZau+Ib5VgWznuA6gljqO8pq+A6RdOLyUE54lB9WxKLPs5MTaNMHdZDZs0hSiwepoRV1Emnj8I9aZmqQA4pq5FD4/tCtJ/xE/6k+ZDxK8Syqbn1ELoyP3LJs4u0u7u70+PwiCrSWSH9kU2qf3MAg/XhE5Jr4fAxTSjN3Gx7Z5y1oJBHcSqhNcIGJTJ1oGhbMUSSczn8fSIqWQlwSCCGOooTBt9qKbQSCQTLlmlKmUhz4uYRKzKSldZKl1YJGRr575e15cYrlDekUoUWVXb0Mdeo8PhDWFvyGpUd6joRrF8mYQCWOQ6d4eqRlxgeyLLgOWKFU0owFIqCi45/J4RxuUU7ZCp88sxDEE8GL1p4+cDTpgOv1lEbQs4U1NcPkFgeQA8BAhMGMRJ1Gy4qg6zz8gQzBvJn51MKgcuUEyVFs9B6xWCyLGWR3MW
XJI2y8NB9ZwvmW04lMaORtQu0CS5hfM5bxeMpnAn/AOQRWK9jVZ6l44PC3KZnozecVTl4hRgBwH99YGRpyiSDn4xRoya28MkpBD50oXzHAvrQ+ceQ43Zq19Yutij2iq5rrxq9YLuI95R1CFkHYhmI2NTE5Ssrjwp6pKIQZ60DEpsS0guCxwmo7p56CrGphbNtRLkqJJ0isrJCiST3t/4jA69INNWY1WtKSVngNVKfvPSjDicmD0q9Yts1kVRbOAVUJbIZirn+0L5Ki+e0aa7pYJlAgEFC3BDg11GsNUk4op01ONV5+X5RxExGMKwBqMAGy4DlrA9pvIAkhwOQp4QAVFl10/8AsRCu0KL5xOMFctU6uUY+KsO7bfRmYQtSlJAZmTltSBOyC1dxgCHY5gBhVuMKwo7w5sZaUDr/APqKtaY4M8JutPzGFinTUy1nH3H7wDVoOEBTr2Dkd7oIIvlRcV0+MKJWU46hNDqKjKFgk1qZpqzkpKnF7f8AGyxdpcUNYH+8GBUGIxbY82VSUnk//9k="
|
| 11 |
+
}
|
| 12 |
+
}
|
upload_model.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: 2025 UL Research Institutes
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
import sys
|
| 5 |
+
import time
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import click
|
| 9 |
+
|
| 10 |
+
from dyff.client import Client
|
| 11 |
+
from dyff.schema.platform import *
|
| 12 |
+
from dyff.schema.requests import *
|
| 13 |
+
|
| 14 |
+
from app.api.models import PredictionResponse
|
| 15 |
+
|
| 16 |
+
# ----------------------------------------------------------------------------
|
| 17 |
+
|
| 18 |
+
WORKDIR = Path(__file__).resolve().parent
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@click.command()
@click.option(
    "--account",
    type=str,
    required=True,
    help="Your account ID",
)
@click.option(
    "--name",
    type=str,
    required=True,
    help="The name of your detector model. For display and querying purposes only.",
)
@click.option(
    "--image",
    type=str,
    required=True,
    # Typo fix: "deamon" -> "daemon" (user-facing help text).
    help="The Docker image to upload. Must exist in your local Docker daemon.",
)
@click.option(
    "--endpoint",
    type=str,
    default="predict",
    help="The endpoint to call on your model to make a prediction.",
)
def main(account: str, name: str, image: str, endpoint: str) -> None:
    """Upload a local Docker image as a Dyff Artifact and register it as a
    runnable InferenceService.

    Set ``artifact_id`` / ``service_id`` below to known IDs to skip the
    corresponding creation step on re-runs.
    """
    dyffapi = Client()

    # You can set these to a known ID to skip that step
    artifact_id = None
    service_id = None

    # Upload the image
    if artifact_id is None:
        # Create an Artifact resource
        artifact = dyffapi.artifacts.create(ArtifactCreateRequest(account=account))
        click.echo(f"artifact_id = \"{artifact.id}\"")
        # NOTE(review): fixed sleeps presumably give the platform time to
        # register the resource — consider polling status instead; confirm.
        time.sleep(5)
        # Push the image from the local Docker daemon
        dyffapi.artifacts.push(artifact, source=f"docker-daemon:{image}")
        time.sleep(5)
        # Indicate that we're done pushing
        dyffapi.artifacts.finalize(artifact.id)
    else:
        artifact = dyffapi.artifacts.get(artifact_id)
        assert artifact is not None

    # Create a runnable InferenceService
    if service_id is None:
        # Don't change this
        service_request = InferenceServiceCreateRequest(
            account=account,
            name=name,
            model=None,
            runner=InferenceServiceRunner(
                kind=InferenceServiceRunnerKind.CONTAINER,
                imageRef=EntityIdentifier.of(artifact),
                resources=ModelResources(),
            ),
            interface=InferenceInterface(
                endpoint=endpoint,
                outputSchema=DataSchema.make_output_schema(PredictionResponse),
            ),
        )
        service = dyffapi.inferenceservices.create(service_request)
        click.echo(f"service_id = \"{service.id}\"")
    else:
        service = dyffapi.inferenceservices.get(service_id)
        assert service is not None


if __name__ == "__main__":
    main()
|
upload_submission.sh
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Submission helper: points challenge-cli at the Dyff API and uploads the
# locally built Docker image as a challenge submission.
export DOCKER_HOST=$DOCKER_HOST
echo DOCKER_HOST is $DOCKER_HOST
# SECURITY NOTE(review): a live-looking bearer token (JWT) is committed here in
# plaintext. This credential should be rotated/revoked and loaded from the
# environment or a secrets manager rather than hard-coded in version control.
export DYFF_API_TOKEN=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJqdGkiOiJiODBkZGY4NmUyNzk0MzI3OTNkOTlmNTk2MGE1MWRlZSIsInN1YiI6IkFjY291bnQvN2IxZWZiOWMwMDMyNGNkYzlhNDEyY2M4YmQ4ZjcwZTIiLCJuYmYiOjE3NjUzMTg1NjMsImV4cCI6MTc5Njg1NDU2MywiaWF0IjoxNzY1MzE4NTYzLCJncmFudHMiOlt7InJlc291cmNlcyI6WyIqIl0sImZ1bmN0aW9ucyI6WyJjb25zdW1lIiwiZGF0YSIsImdldCIsInF1ZXJ5Il0sImFjY291bnRzIjpbInB1YmxpYyJdLCJlbnRpdGllcyI6W119LHsicmVzb3VyY2VzIjpbImNoYWxsZW5nZXMiXSwiZnVuY3Rpb25zIjpbImdldCIsInF1ZXJ5Il0sImFjY291bnRzIjpbXSwiZW50aXRpZXMiOlsiZGM1MDlhOGM3NzFiNDkyYjkwYzQzMDEyZmRlOWEwNGYiXX0seyJyZXNvdXJjZXMiOlsidGVhbXMiXSwiZnVuY3Rpb25zIjpbIioiXSwiYWNjb3VudHMiOltdLCJlbnRpdGllcyI6WyIzYWUzMzdjMTY2MGQ0YjczOGM0NDNjMmMzN2ZjMWJjYSJdfSx7InJlc291cmNlcyI6WyJhcnRpZmFjdHMiLCJpbmZlcmVuY2VzZXJ2aWNlcyIsIm1vZGVscyIsInN1Ym1pc3Npb25zIl0sImZ1bmN0aW9ucyI6WyIqIl0sImFjY291bnRzIjpbIjdiMWVmYjljMDAzMjRjZGM5YTQxMmNjOGJkOGY3MGUyIl0sImVudGl0aWVzIjpbXX0seyJyZXNvdXJjZXMiOlsiKiJdLCJmdW5jdGlvbnMiOlsiY29uc3VtZSIsImRhdGEiLCJnZXQiLCJxdWVyeSJdLCJhY2NvdW50cyI6WyI3YjFlZmI5YzAwMzI0Y2RjOWE0MTJjYzhiZDhmNzBlMiJdLCJlbnRpdGllcyI6W119XSwic2VjcmV0IjoiM3p3MjFHUFI0WGRpenpIVTdnT0ZBdyJ9.bR71zVzqrRRsK_FZCYMGmTYCsQekz6LfmzM-EoyBOW0
# Runs the upload via the project's virtualenv Python (machine-specific path).
/media/NAS/USERS/shahid/WACV/example-submission/venv/bin/python challenge-cli.py upload-submission --name "dino" --account 7b1efb9c00324cdc9a412cc8bd8f70e2 --image safe-challenge-2025/example-submission:latest
#
|
utils/My_version.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
| 2 |
+
# Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA').
|
| 3 |
+
#
|
| 4 |
+
# All rights reserved.
|
| 5 |
+
# This work should only be used for nonprofit purposes.
|
| 6 |
+
#
|
| 7 |
+
# By downloading and/or using any of these files, you implicitly agree to all the
|
| 8 |
+
# terms of the license, as specified in the document LICENSE.txt
|
| 9 |
+
# (included in this package) and online at
|
| 10 |
+
# http://www.grip.unina.it/download/LICENSE_OPEN.txt
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
Created in September 2022
|
| 14 |
+
@author: fabrizio.guillaro
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
from torch.nn import functional as F
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CrossEntropy(nn.Module):
    """Pixel-wise cross-entropy loss.

    Resizes the prediction map to the target's spatial size when they differ.

    Args:
        ignore_label: target value excluded from the loss (padding regions).
        weight: optional per-class rescaling weights for CrossEntropyLoss.
    """

    def __init__(self, ignore_label=-1, weight=None):
        super(CrossEntropy, self).__init__()
        self.ignore_label = ignore_label
        self.criterion = nn.CrossEntropyLoss(weight=weight,
                                             ignore_index=ignore_label)

    def forward(self, score, target):
        # score: (B, C, ph, pw) logits; target: (B, h, w) class indices.
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate is the drop-in
            # replacement with identical behavior.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear')

        loss = self.criterion(score, target)
        return loss
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class DiceLoss(nn.Module):
    """Multi-class Dice loss on softmax probabilities.

    Pixels whose target equals ``ignore_label`` are masked out of the Dice
    statistics (padding adds -1 to the targets).
    """

    def __init__(self, ignore_label=-1, smooth=1, exponent=2):
        super(DiceLoss, self).__init__()
        self.ignore_index = ignore_label
        self.smooth = smooth      # additive smoothing on numerator/denominator
        self.exponent = exponent  # power applied to pred/target in denominator

    def dice_loss(self, pred, target, valid_mask):
        """Average the per-channel binary Dice loss over all class channels."""
        assert pred.shape[0] == target.shape[0]
        total_loss = 0
        num_classes = pred.shape[1]
        for i in range(num_classes):
            if i != self.ignore_index:
                dice_loss = self.binary_dice_loss(
                    pred[:, i],
                    target[..., i],
                    valid_mask=valid_mask,)
                total_loss += dice_loss
        return total_loss / num_classes

    def binary_dice_loss(self, pred, target, valid_mask):
        """Dice loss for a single class channel, masked by ``valid_mask``."""
        assert pred.shape[0] == target.shape[0]
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

        num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
        den = torch.sum(pred.pow(self.exponent)*valid_mask + target.pow(self.exponent)*valid_mask, dim=1) + max(self.smooth, 1e-5)

        dice = num / den
        dice = torch.mean(dice)
        return 1 - dice

    def forward(self, score, target):
        # score: (B, C, ph, pw) logits; target: (B, h, w) class indices.
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate is the drop-in
            # replacement with identical behavior.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear')

        score = F.softmax(score, dim=1)
        num_classes = score.shape[1]

        # Clamp so the ignore label (-1) one-hot encodes safely; those pixels
        # are excluded via valid_mask anyway.
        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()

        loss = self.dice_loss(score, one_hot_target, valid_mask)
        return loss
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class BinaryDiceLoss(nn.Module):
    """Dice loss on the foreground channel (class 1) of the softmax output.

    Pixels whose target equals ``ignore_label`` are masked out (padding adds
    -1 to the targets).
    """

    def __init__(self, smooth=1, exponent=2, ignore_label=-1):
        super(BinaryDiceLoss, self).__init__()
        self.ignore_index = ignore_label
        self.smooth = smooth      # additive smoothing on numerator/denominator
        self.exponent = exponent  # power applied to pred/target in denominator

    def binary_dice_loss(self, pred, target, valid_mask):
        """Dice loss for a single class channel, masked by ``valid_mask``."""
        assert pred.shape[0] == target.shape[0]
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

        num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
        den = torch.sum(pred.pow(self.exponent)*valid_mask + target.pow(self.exponent)*valid_mask, dim=1) + max(self.smooth, 1e-5)

        dice = num / den
        dice = torch.mean(dice)
        return 1 - dice

    def forward(self, score, target):
        # score: (B, C, ph, pw) logits; target: (B, h, w) class indices.
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate is the drop-in
            # replacement with identical behavior.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear')

        score = F.softmax(score, dim=1)
        num_classes = score.shape[1]

        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()

        # Only the foreground channel (class 1) contributes.
        loss = self.binary_dice_loss(
            score[:, 1],
            one_hot_target[..., 1],
            valid_mask)
        return loss
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class DiceEntropyLoss(nn.Module):
    """Weighted sum of cross-entropy (0.3) and binary Dice loss (0.7).

    The Dice term is computed on the foreground channel (class 1) of the
    softmax output; pixels whose target equals ``ignore_label`` are masked
    out of it (padding adds -1 to the targets).
    """

    def __init__(self, smooth=1, exponent=2, ignore_label=-1, weight=None):
        super(DiceEntropyLoss, self).__init__()
        self.ignore_label = ignore_label
        self.smooth = smooth      # additive smoothing on numerator/denominator
        self.exponent = exponent  # power applied to pred/target in denominator
        self.cross_entropy = nn.CrossEntropyLoss(weight=weight,
                                                 ignore_index=ignore_label)

    def binary_dice_loss(self, pred, target, valid_mask):
        """Dice loss for a single class channel, masked by ``valid_mask``."""
        # NOTE: removed leftover debug print that ran on every forward pass.
        assert pred.shape[0] == target.shape[0]
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

        num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
        den = torch.sum(pred.pow(self.exponent)*valid_mask + target.pow(self.exponent)*valid_mask, dim=1) + max(self.smooth, 1e-5)

        dice = num / den
        dice = torch.mean(dice)
        return 1 - dice

    def forward(self, score, target):
        # score: (B, C, h, w) logits; target: (B, h, w) class indices.
        # (The original resize-to-target branch was already commented out;
        # this loss assumes score and target share spatial dims.)
        CE_loss = self.cross_entropy(score, target)

        score = F.softmax(score, dim=1)
        num_classes = score.shape[1]

        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_label).long()
        # BUGFIX: was `one_hot_target[:, 1][..., 1]`, which slices spatial
        # row 1 instead of the class-1 channel and yields a tensor whose
        # flattened size differs from `score[:, 1]`. Index the class channel
        # directly, matching BinaryDiceLoss.
        dice_loss = self.binary_dice_loss(
            score[:, 1],
            one_hot_target[..., 1],
            valid_mask)

        return 0.3 * CE_loss + 0.7 * dice_loss
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class FocalLoss(nn.Module):
    """Focal loss: ``alpha * (1 - pt) ** gamma * CE`` averaged over pixels.

    Down-weights well-classified pixels (high pt) so training focuses on
    hard examples.
    """

    def __init__(self, alpha=0.25, gamma=2., ignore_label=-1):  # alpha 0.25, gamma=2.
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_label, reduction="none")

    def forward(self, score, target):
        # score: (B, C, ph, pw) logits; target: (B, h, w) class indices.
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate is the drop-in
            # replacement with identical behavior.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear')

        ce_loss = self.criterion(score, target)
        # pt = probability assigned to the true class; (1 - pt) ** gamma
        # shrinks the loss contribution of easy pixels.
        pt = torch.exp(-ce_loss)
        f_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return f_loss.mean()
|
| 206 |
+
|
| 207 |
+
|
utils/losses.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
| 2 |
+
# Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA').
|
| 3 |
+
#
|
| 4 |
+
# All rights reserved.
|
| 5 |
+
# This work should only be used for nonprofit purposes.
|
| 6 |
+
#
|
| 7 |
+
# By downloading and/or using any of these files, you implicitly agree to all the
|
| 8 |
+
# terms of the license, as specified in the document LICENSE.txt
|
| 9 |
+
# (included in this package) and online at
|
| 10 |
+
# http://www.grip.unina.it/download/LICENSE_OPEN.txt
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
Created in September 2022
|
| 14 |
+
@author: fabrizio.guillaro
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
from torch.nn import functional as F
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CrossEntropy(nn.Module):
    """Pixel-wise cross-entropy loss.

    Resizes the prediction map to the target's spatial size when they differ.

    Args:
        ignore_label: target value excluded from the loss (padding regions).
        weight: optional per-class rescaling weights for CrossEntropyLoss.
    """

    def __init__(self, ignore_label=-1, weight=None):
        super(CrossEntropy, self).__init__()
        self.ignore_label = ignore_label
        self.criterion = nn.CrossEntropyLoss(weight=weight,
                                             ignore_index=ignore_label)

    def forward(self, score, target):
        # score: (B, C, ph, pw) logits; target: (B, h, w) class indices.
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate is the drop-in
            # replacement with identical behavior.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear')

        loss = self.criterion(score, target)
        return loss
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class DiceLoss(nn.Module):
    """Multi-class Dice loss on softmax probabilities.

    Pixels whose target equals ``ignore_label`` are masked out of the Dice
    statistics (padding adds -1 to the targets).
    """

    def __init__(self, ignore_label=-1, smooth=1, exponent=2):
        super(DiceLoss, self).__init__()
        self.ignore_index = ignore_label
        self.smooth = smooth      # additive smoothing on numerator/denominator
        self.exponent = exponent  # power applied to pred/target in denominator

    def dice_loss(self, pred, target, valid_mask):
        """Average the per-channel binary Dice loss over all class channels."""
        assert pred.shape[0] == target.shape[0]
        total_loss = 0
        num_classes = pred.shape[1]
        for i in range(num_classes):
            if i != self.ignore_index:
                dice_loss = self.binary_dice_loss(
                    pred[:, i],
                    target[..., i],
                    valid_mask=valid_mask,)
                total_loss += dice_loss
        return total_loss / num_classes

    def binary_dice_loss(self, pred, target, valid_mask):
        """Dice loss for a single class channel, masked by ``valid_mask``."""
        assert pred.shape[0] == target.shape[0]
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

        num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
        den = torch.sum(pred.pow(self.exponent)*valid_mask + target.pow(self.exponent)*valid_mask, dim=1) + max(self.smooth, 1e-5)

        dice = num / den
        dice = torch.mean(dice)
        return 1 - dice

    def forward(self, score, target):
        # score: (B, C, ph, pw) logits; target: (B, h, w) class indices.
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated; F.interpolate is the drop-in
            # replacement with identical behavior.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear')

        score = F.softmax(score, dim=1)
        num_classes = score.shape[1]

        # Clamp so the ignore label (-1) one-hot encodes safely; those pixels
        # are excluded via valid_mask anyway.
        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()

        loss = self.dice_loss(score, one_hot_target, valid_mask)
        return loss
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class BinaryDiceLoss(nn.Module):
    """Soft dice loss on the foreground channel of a 2-class prediction."""

    def __init__(self, smooth=1, exponent=2, ignore_label=-1):
        # padding fills targets with -1, so that label must be ignorable
        super(BinaryDiceLoss, self).__init__()
        self.ignore_index = ignore_label
        self.smooth = smooth
        self.exponent = exponent

    def binary_dice_loss(self, pred, target, valid_mask):
        """Dice loss for one channel; valid_mask zeroes out ignored pixels."""
        assert pred.shape[0] == target.shape[0]
        # NOTE: leftover debug prints removed from the original.
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

        num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
        den = torch.sum(pred.pow(self.exponent) * valid_mask
                        + target.pow(self.exponent) * valid_mask, dim=1) + max(self.smooth, 1e-5)

        dice = torch.mean(num / den)
        return 1 - dice

    def forward(self, score, target):
        """
        score: B x 2 x Hs x Ws raw logits; target: B x 1 x H x W binary mask.
        Returns scalar dice loss on the foreground (class 1) channel.
        """
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(2), target.size(3)
        if ph != h or pw != w:
            # F.upsample is deprecated; align_corners=False matches its old default
            score = F.interpolate(score, size=(h, w), mode='bilinear', align_corners=False)

        score = F.softmax(score, dim=1)
        num_classes = score.shape[1]

        # Clamp so ignore labels (-1) index a valid one-hot slot; those pixels
        # are masked out via valid_mask below.
        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()

        return self.binary_dice_loss(
            score[:, 1],
            one_hot_target[..., 1],
            valid_mask)
def create_target_from_mask_and_label(mask, data_label):
    """
    Convert binary mask to class-labeled target based on data_label.

    Args:
        mask: B x H x W (or B x 1 x H x W) with values 0 (black/background)
            or 1 (white/foreground)
        data_label: B x 1 tensor, B tensor, or scalar tensor with values [0, 1, 2, 3]
            - 0: background (no edit)
            - 1: physical edit (Photoshop)
            - 2: synthetic AI edit
            - 3: other edit type

    Returns:
        target: B x H x W long tensor with values [0, 1, 2, 3]
            - 0: unedited pixels (mask == 0)
            - 1, 2, 3: edited pixels carrying their sample's class label
    """
    # Drop the channel dimension if present
    if mask.dim() == 4:  # B x 1 x H x W
        mask = mask.squeeze(1)  # B x H x W

    # Collapse B x 1 (and similar) label tensors down to B
    if data_label.dim() > 1:
        data_label = data_label.squeeze()

    B, H, W = mask.shape

    # One label per sample, on the mask's device (a scalar label is shared
    # across the whole batch, as in the original per-sample loop).
    if data_label.dim() == 0:
        labels = data_label.long().to(mask.device).expand(B)
    else:
        labels = data_label.long().to(mask.device)

    # Vectorized replacement for the original per-sample Python loop:
    # pixels where mask == 1 take the sample's label, everything else stays
    # background (0). Avoids one .item() GPU sync per batch element.
    background = torch.zeros(B, H, W, dtype=torch.long, device=mask.device)
    target = torch.where(mask == 1, labels.reshape(B, 1, 1).expand(B, H, W), background)
    return target
def debug_target_creation(target, data_label, batch_size=4):
    """
    Debug function to print data_label and target mapping before and after conversion.

    Args:
        target: Binary mask B×H×W or B×1×H×W with values 0 or 1
        data_label: B tensor with class labels [0, 1, 2, 3]
        batch_size: only the first `batch_size` samples are printed

    Returns:
        The converted target from create_target_from_mask_and_label
        (B×H×W with values [0, 1, 2, 3]).
    """

    print("="*80)
    print("DEBUGGING TARGET CREATION")
    print("="*80)

    # Print original inputs
    print("\n--- BEFORE CONVERSION ---")
    print(f"Data Label shape: {data_label.shape}")
    print(f"Data Label values: {data_label}")
    print(f"Data Label dtype: {data_label.dtype}")

    print(f"\nTarget (mask) shape: {target.shape}")
    print(f"Target (mask) unique values: {torch.unique(target)}")
    print(f"Target (mask) dtype: {target.dtype}")

    # Print per-sample details BEFORE
    print("\n--- PER-SAMPLE BREAKDOWN (BEFORE) ---")
    # Work on a B×H×W view regardless of whether a channel dim is present
    if target.dim() == 4:  # B×1×H×W
        target_2d = target.squeeze(1)
    else:
        target_2d = target

    B = target_2d.shape[0]
    for b in range(min(B, batch_size)):
        edited_pixels = (target_2d[b] == 1).sum().item()
        total_pixels = target_2d[b].numel()
        # data_label may be a scalar (0-dim) tensor shared by the batch
        label = data_label[b].item() if data_label.dim() > 0 else data_label.item()
        print(f"  Sample {b}: Label={label}, Edited pixels={edited_pixels}/{total_pixels}")

    # Create target
    target_converted = create_target_from_mask_and_label(target, data_label)

    # Print AFTER conversion
    print("\n--- AFTER CONVERSION ---")
    print(f"Target (converted) shape: {target_converted.shape}")
    print(f"Target (converted) unique values: {torch.unique(target_converted)}")
    print(f"Target (converted) dtype: {target_converted.dtype}")

    # Print per-sample details AFTER
    print("\n--- PER-SAMPLE BREAKDOWN (AFTER) ---")
    for b in range(min(B, batch_size)):
        label = data_label[b].item() if data_label.dim() > 0 else data_label.item()

        # Count pixels for each class
        class_counts = {}
        for class_id in range(4):
            count = (target_converted[b] == class_id).sum().item()
            class_counts[class_id] = count

        print(f"  Sample {b}:")
        print(f"    Label (expected): {label}")
        print(f"    Class distribution: {class_counts}")

        # Verify correctness
        if label == 0:
            # All pixels should be background (0)
            if class_counts[0] == target_converted[b].numel():
                print(f"    ✓ CORRECT: All pixels are class 0 (background)")
            else:
                print(f"    ✗ ERROR: Expected all pixels to be 0, but got {class_counts}")
        else:
            # Non-background pixels should have the label
            if class_counts[label] > 0:
                print(f"    ✓ CORRECT: Found {class_counts[label]} pixels with class {label}")
            else:
                print(f"    ✗ ERROR: Expected class {label} pixels but found none")

    print("\n" + "="*80)

    return target_converted
class MultiClassDiceEntropyLoss(nn.Module):
    """
    Multi-class segmentation loss combining Dice and CrossEntropy.
    Supports classes: 0 (background), 1, 2, 3.
    """

    def __init__(self, num_classes=4, smooth=1e-5, dice_weight=0.5, ce_weight=0.5,
                 ignore_index=-1):
        super(MultiClassDiceEntropyLoss, self).__init__()
        self.num_classes = num_classes
        self.smooth = smooth              # smoothing term for the dice ratio
        self.dice_weight = dice_weight    # weight of the dice term
        self.ce_weight = ce_weight        # weight of the cross-entropy term
        self.ignore_index = ignore_index

        # CrossEntropy handles the ignore label natively
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def dice_loss(self, pred, target, valid_mask=None):
        """
        Compute Dice loss per class and average.

        pred: B×C×H×W (softmax probabilities)
        target: B×H×W (class indices 0..num_classes-1)
        valid_mask: B×H×W (1 for valid, 0 for ignore), optional
        """
        dice_losses = []

        for class_id in range(self.num_classes):
            # Binary problem for this class
            pred_flat = pred[:, class_id, :, :].reshape(-1)
            target_flat = (target == class_id).float().reshape(-1)

            # Zero out ignored pixels if a mask is provided
            if valid_mask is not None:
                valid_flat = valid_mask.reshape(-1)
                pred_flat = pred_flat * valid_flat
                target_flat = target_flat * valid_flat

            intersection = torch.sum(pred_flat * target_flat)
            union = torch.sum(pred_flat) + torch.sum(target_flat)

            dice = (2 * intersection + self.smooth) / (union + self.smooth)
            dice_losses.append(1 - dice)

        return torch.mean(torch.stack(dice_losses))

    def forward(self, score, target, data_label):
        """
        score: B×C×H×W raw logits (C must equal num_classes)
        target: B×H×W (or B×1×H×W) class labels in {0, 1, 2, 3}
        data_label: per-sample edit-type labels; currently unused, kept for
            interface compatibility with existing callers.
        """
        # Handle if target has a channel dimension
        if target.dim() == 4:  # B×1×H×W
            target = target.squeeze(1)  # B×H×W

        # Ensure target is long type for CrossEntropy / comparisons
        target = target.long()

        # Upsample logits to the target resolution if needed
        if score.shape[2:] != target.shape[1:]:
            score = F.interpolate(score, size=target.shape[1:], mode='bilinear', align_corners=False)

        # Softmax probabilities feed the dice term; CE consumes raw logits
        score_probs = F.softmax(score, dim=1)

        ce_loss = self.ce_loss(score, target)

        # Exclude ignore_index pixels from the dice computation
        valid_mask = (target != self.ignore_index).float()

        dice_loss = self.dice_loss(score_probs, target, valid_mask)

        return self.dice_weight * dice_loss + self.ce_weight * ce_loss
class DiceEntropyLoss(nn.Module):
    """Weighted sum of cross-entropy (0.3) and soft dice (0.7) for binary
    segmentation. Padding pixels are labeled -1 and ignored by the CE term."""

    def __init__(self, smooth=1, exponent=2, ignore_label=-1, weight=None):
        # padding adds -1 to the targets, so that label must be ignorable
        super(DiceEntropyLoss, self).__init__()
        self.ignore_label = ignore_label
        self.smooth = smooth
        self.exponent = exponent
        self.cross_entropy = nn.CrossEntropyLoss(weight=weight,
                                                 ignore_index=ignore_label)

    def binary_dice_loss(self, pred, target, valid_mask):
        """Soft dice loss for one class channel; valid_mask zeroes ignored pixels."""
        assert pred.shape[0] == target.shape[0]
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

        num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
        den = torch.sum(pred.pow(self.exponent) * valid_mask
                        + target.pow(self.exponent) * valid_mask, dim=1) + max(self.smooth, 1e-5)

        dice = torch.mean(num / den)
        return 1 - dice

    def forward(self, score, target):
        """
        score: B×C×Hs×Ws raw logits; target: B×1×H×W (binarized to {0, 1}).
        Returns 0.3 * CE + 0.7 * dice averaged over present foreground classes.
        """
        target = target.squeeze(1).long()
        # Collapse any multi-class labels into a binary edit/no-edit map
        target = torch.clamp(target, min=0, max=1)

        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated; align_corners=False matches its old default
            score = F.interpolate(score, size=(h, w), mode='bilinear', align_corners=False)

        ce_loss = self.cross_entropy(score, target)

        score = F.softmax(score, dim=1)
        num_classes = score.shape[1]

        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_label).long()

        # BUG FIX: the original loop referenced undefined names (`dice`,
        # `pred`, `target_onehot`) and divided by a hard-coded 4. Compute the
        # binary dice per foreground class with this class's own helper and
        # average over the classes actually present in the batch.
        dice_total = 0.0
        present = 0
        for class_id in range(1, num_classes):
            if (target == class_id).sum() > 0:
                dice_total = dice_total + self.binary_dice_loss(
                    score[:, class_id],
                    one_hot_target[..., class_id],
                    valid_mask)
                present += 1
        dice_loss = dice_total / max(present, 1)

        return 0.3 * ce_loss + 0.7 * dice_loss
class FocalLoss(nn.Module):
    """Focal loss (Lin et al., 2017) applied on top of per-pixel cross-entropy."""

    def __init__(self, alpha=0.25, gamma=2., ignore_label=-1):
        # common defaults: alpha=0.25, gamma=2.0
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        # reduction="none" keeps per-pixel CE so the focal weighting can be
        # applied element-wise before averaging
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_label, reduction="none")

    def forward(self, score, target):
        """score: B×C×Hs×Ws logits; target: B×H×W class indices."""
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated; align_corners=False matches its old default
            score = F.interpolate(score, size=(h, w), mode='bilinear', align_corners=False)

        ce_loss = self.criterion(score, target)
        pt = torch.exp(-ce_loss)  # model's probability of the true class
        f_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return f_loss.mean()