Shahidmuneer committed
Commit 8bd3ef8 · verified · 1 Parent(s): 2449aad

Upload folder using huggingface_hub
.dockerignore ADDED
@@ -0,0 +1,38 @@
+ __pycache__
+ *.pyc
+ *.pyo
+ *.pyd
+ .Python
+ *.so
+ *.egg
+ *.egg-info
+ dist
+ build
+
+ .venv
+ venv
+ ENV
+ env
+
+ .git
+ .gitignore
+ .idea
+ .vscode
+ .claude
+
+ *.md
+ README.md
+ Dockerfile
+ .dockerignore
+
+ test_*.http
+ test_results
+ scripts/test_datasets
+
+ .pytest_cache
+ .coverage
+ htmlcov
+
+ *.log
+ .DS_Store
+ .python-version
.env.example ADDED
@@ -0,0 +1,11 @@
+ # App Configuration
+ APP_NAME="ML Inference Service"
+ APP_VERSION="0.1.0"
+ DEBUG=false
+
+ # Server
+ HOST="0.0.0.0"
+ PORT=8000
+
+ # Model
+ MODEL_NAME="microsoft/resnet-18"
.gitignore ADDED
@@ -0,0 +1,7 @@
+ dyff-outputs/
+ models/
+ test_datasets/
+ test_results/
+ venv/
+ **/__pycache__
+ *.tmp
.python-version ADDED
@@ -0,0 +1 @@
+ 3.12.11
Dockerfile ADDED
@@ -0,0 +1,24 @@
+ FROM python:3.12-slim
+
+ ENV PYTHONDONTWRITEBYTECODE="1" \
+     PYTHONUNBUFFERED="1"
+
+ # hadolint ignore=DL3013
+ RUN python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel
+
+ WORKDIR /app/
+
+ COPY requirements.cpu.txt ./
+ RUN python3 -m pip install --no-cache-dir -r ./requirements.cpu.txt
+
+ COPY requirements.torch.cpu.txt ./
+ RUN python3 -m pip install --no-cache-dir -r ./requirements.torch.cpu.txt
+
+ COPY app ./app
+ COPY models ./models
+ COPY main.py ./
+
+ # Match the port that uvicorn binds to in the ENTRYPOINT below
+ EXPOSE 7860
+
+ # ENTRYPOINT ["python3", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
+ ENTRYPOINT ["python3", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,371 @@
+ ---
+ license: apache-2.0
+ ---
+
+ # ML Inference Service
+
+ FastAPI service for serving ML models over HTTP. Comes with ResNet-18 for image classification out of the box, but you can swap in any model you want.
+
+ ## Quick Start
+
+ **Install `uv`:**
+ https://docs.astral.sh/uv/getting-started/installation/
+
+ **Local development:**
+ ```bash
+ # Install dependencies
+ make setup
+ source venv/bin/activate
+
+ # Download the example model
+ make download
+
+ # Run it
+ make serve
+ ```
+
+ In a second terminal:
+ ```bash
+ # Process an example input
+ ./prompt.sh cat.json
+ ```
+
+ Server runs on `http://127.0.0.1:8000`. Check `/docs` for the interactive API documentation.
+
+ **Docker:**
+ ```bash
+ # Build
+ make docker-build
+
+ # Run
+ make docker-run
+ ```
+
+ ## Testing the API
+
+ ```bash
+ # Using curl
+ curl -X POST http://localhost:8000/predict \
+   -H "Content-Type: application/json" \
+   -d '{
+     "image": {
+       "mediaType": "image/jpeg",
+       "data": "<base64-encoded-image>"
+     }
+   }'
+ ```
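+
+ If you prefer Python over curl, here is a minimal client sketch. It assumes the service is running locally and that the third-party `requests` package is installed in your client environment (it is not part of this project's pinned requirements); the image path is illustrative:
+
+ ```python
+ import base64
+ import json
+
+ import requests  # assumption: installed separately in the client environment
+
+ # Base64-encode any local JPEG (path is illustrative)
+ with open("cat.jpg", "rb") as f:
+     payload = {
+         "image": {
+             "mediaType": "image/jpeg",
+             "data": base64.b64encode(f.read()).decode("utf-8"),
+         }
+     }
+
+ resp = requests.post("http://localhost:8000/predict", json=payload, timeout=60)
+ resp.raise_for_status()
+ print(json.dumps(resp.json(), indent=2)[:500])  # print the start of the response
+ ```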
+
+ Example response:
+ ```json
+ {
+   "logprobs": [-0.859380304813385, -1.2701971530914307, -2.1918208599090576, -1.69235098361969],
+   "localizationMask": {
+     "mediaType": "image/png",
+     "data": "iVBORw0KGgoAAAANSUhEUgAAA8AAAAKDAQAAAAD9Fl5AAAAAu0lEQVR4nO3NsREAMAgDMWD/nZMVKEwn1T5/FQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAMCl3g5f+HC24TRhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAj70gwKsTlmdBwAAAABJRU5ErkJggg=="
+   }
+ }
+ ```
+
+ ## Project Structure
+
+ ```
+ example-submission/
+ ├── main.py                  # Entry point
+ ├── app/
+ │   ├── core/
+ │   │   ├── app.py           # <= INSTANTIATE YOUR DETECTOR HERE
+ │   │   └── logging.py       # Logging setup
+ │   ├── api/
+ │   │   ├── models.py        # Request/response schemas
+ │   │   ├── controllers.py   # Business logic
+ │   │   └── routes/
+ │   │       └── prediction.py  # POST /predict
+ │   └── services/
+ │       ├── base.py          # <= YOUR DETECTOR IMPLEMENTS THIS INTERFACE
+ │       └── inference.py     # Example service based on ResNet-18
+ ├── models/
+ │   └── microsoft/
+ │       └── resnet-18/       # Model weights and config
+ ├── scripts/
+ │   ├── model_download.bash
+ │   ├── generate_test_datasets.py
+ │   └── test_datasets.py
+ ├── Dockerfile
+ ├── .env.example             # Environment config template
+ ├── cat.json                 # An example /predict request object
+ ├── makefile
+ ├── prompt.sh                # Script that makes a /predict request
+ ├── requirements.in
+ ├── requirements.txt
+ └── response.json            # An example /predict response object
+ ```
+
+ ## How to Plug In Your Own Model
+
+ To integrate your model, implement the `InferenceService` abstract class defined in `app/services/base.py`. You can follow the example implementation in `app/services/inference.py`, which is based on ResNet-18. After implementing the required interface, instantiate your model in the `lifespan()` function in `app/core/app.py`, replacing the `ResNetInferenceService` instance.
+
+ ### Step 1: Create Your Service Class
+
+ ```python
+ # app/services/your_model_service.py
+ from app.services.base import InferenceService
+ from app.api.models import ImageRequest, PredictionResponse
+
+ class YourModelService(InferenceService[ImageRequest, PredictionResponse]):
+     def __init__(self, model_name: str):
+         self.model_name = model_name
+         self.model_path = f"models/{model_name}"
+         self.model = None
+         self._is_loaded = False
+
+     def load_model(self) -> None:
+         """Load your model here. Called once at startup."""
+         self.model = load_your_model(self.model_path)
+         self._is_loaded = True
+
+     def predict(self, request: ImageRequest) -> PredictionResponse:
+         """Actual inference happens here."""
+         image = decode_base64_image(request.image.data)
+         result = self.model(image)
+
+         logprobs = ...
+         mask = ...
+
+         return PredictionResponse(
+             logprobs=logprobs,
+             localizationMask=mask,
+         )
+
+     @property
+     def is_loaded(self) -> bool:
+         return self._is_loaded
+ ```
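+
+ The `load_your_model` and `decode_base64_image` helpers above are placeholders for your own code. For the decoding step, here is a minimal sketch of one possible implementation, mirroring what the bundled ResNet service does with `base64` and Pillow (both already dependencies of this project):
+
+ ```python
+ import base64
+ from io import BytesIO
+
+ from PIL import Image
+
+ def decode_base64_image(data: str) -> Image.Image:
+     """Decode a base64 string into an RGB PIL image."""
+     image = Image.open(BytesIO(base64.b64decode(data)))
+     # Normalize palette / grayscale / alpha images to plain RGB
+     return image.convert("RGB") if image.mode != "RGB" else image
+ ```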
+
+ ### Step 2: Register Your Service
+
+ Open `app/core/app.py` and find the lifespan function:
+
+ ```python
+ # Change this line:
+ service = ResNetInferenceService(model_name="microsoft/resnet-18")
+
+ # To this:
+ service = YourModelService(...)
+ ```
+
+ That's it. The `/predict` endpoint now serves your model.
+
+ ### Model Files
+
+ Put your model files under the `models/` directory:
+
+ ```
+ models/
+ └── your-org/
+     └── your-model/
+         ├── config.json
+         ├── weights.bin
+         └── (other files)
+ ```
+
+ ## Configuration
+
+ Settings are managed via environment variables or a `.env` file. See `.env.example` for all available options.
+
+ **Default values:**
+ - `APP_NAME`: "ML Inference Service"
+ - `APP_VERSION`: "0.1.0"
+ - `DEBUG`: false
+ - `HOST`: "0.0.0.0"
+ - `PORT`: 8000
+ - `MODEL_NAME`: "microsoft/resnet-18"
+
+ **To customize:**
+ ```bash
+ # Copy the example
+ cp .env.example .env
+
+ # Edit values
+ vim .env
+ ```
+
+ Or set environment variables directly:
+ ```bash
+ export MODEL_NAME="google/vit-base-patch16-224"
+ uvicorn main:app --reload
+ ```
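+
+ Under the hood these values are read by the pydantic-settings `Settings` class in `app/core/app.py`. A minimal sketch of the resolution order (environment variables take precedence over `.env`, which takes precedence over the defaults); the two fields shown mirror the real class:
+
+ ```python
+ import os
+
+ from pydantic import Field
+ from pydantic_settings import BaseSettings
+
+ class Settings(BaseSettings):
+     # Mirrors two of the fields defined in app/core/app.py
+     app_name: str = Field(default="ML Inference Service")
+     port: int = Field(default=8000)
+
+     class Config:
+         env_file = ".env"
+
+ os.environ["PORT"] = "9000"  # simulates `export PORT=9000`
+ print(Settings().port)       # -> 9000: the env var overrides the default
+ ```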
+
+ ## Deployment
+
+ **Development:**
+ ```bash
+ uvicorn main:app --reload
+ ```
+
+ **Production:**
+ ```bash
+ gunicorn main:app -w 4 -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:8000
+ ```
+
+ The service runs on CPU by default. For GPU inference, install CUDA-enabled PyTorch and modify your service to move tensors to the GPU device.
+
+ **Docker:**
+ - Single-stage build on `python:3.12-slim`
+ - CPU-only PyTorch installed from the pinned `requirements.*.txt` files
+ - Model files baked into the image
+ - The container listens on port 7860 (see the `ENTRYPOINT` in the Dockerfile)
+
+ ## What Happens When You Start the Server
+
+ ```
+ INFO: Starting ML Inference Service...
+ INFO: Initializing ResNet service: models/microsoft/resnet-18
+ INFO: Loading model from models/microsoft/resnet-18
+ INFO: Model loaded: 1000 classes
+ INFO: Startup completed successfully
+ INFO: Uvicorn running on http://0.0.0.0:8000
+ ```
+
+ If you see "Model directory not found", check that your model files exist at the expected path with the full org/model structure.
+
+ ## API Reference
+
+ **Endpoint:** `POST /predict`
+
+ **Request:**
+ ```json
+ {
+   "image": {
+     "mediaType": "image/jpeg",   // or "image/png"
+     "data": "<base64 string>"
+   }
+ }
+ ```
+
+ **Response:**
+ ```json
+ {
+   "logprobs": [float],           // Log-probabilities of each label
+   "localizationMask": {          // [Optional] binary mask
+     "mediaType": "image/png",    // Always png
+     "data": "<base64 string>"    // Image data
+   }
+ }
+ ```
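+
+ To turn the four `logprobs` into a predicted label on the client side, take the argmax and map it through the `Labels` enum from `app/api/models.py` (a small sketch, assuming it runs from the project root with the dependencies installed):
+
+ ```python
+ from app.api.models import Labels
+
+ logprobs = [-0.86, -1.27, -2.19, -1.69]  # illustrative values
+ predicted = Labels(max(range(len(logprobs)), key=logprobs.__getitem__))
+ print(predicted.name)  # -> "Natural" for these values
+ ```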
+
+ **Docs:**
+ - Swagger UI: `http://localhost:8000/docs`
+ - ReDoc: `http://localhost:8000/redoc`
+ - OpenAPI JSON: `http://localhost:8000/openapi.json`
+
+ ## PyArrow Test Datasets
+
+ We've included a test dataset system for validating your model. It generates 100 standardized test cases covering normal inputs, edge cases, performance benchmarks, and model comparisons.
+
+ ### Generate Datasets
+
+ ```bash
+ python scripts/generate_test_datasets.py
+ ```
+
+ This creates:
+ - `scripts/test_datasets/*.parquet` - Test data (images, requests, expected responses); see the inspection sketch after this list
+ - `scripts/test_datasets/*_metadata.json` - Human-readable descriptions
+ - `scripts/test_datasets/datasets_summary.json` - Overview of all datasets
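+
+ A quick way to peek inside one of the generated files is `pyarrow` itself (a minimal sketch; the file name below is illustrative and the schema is whatever the generator wrote, so print it rather than assuming column names):
+
+ ```python
+ import pyarrow.parquet as pq
+
+ # Pick any file produced by generate_test_datasets.py (name is illustrative)
+ table = pq.read_table("scripts/test_datasets/standard_test_001.parquet")
+ print(table.schema)            # inspect the actual columns and types
+ print(table.num_rows, "rows")
+ ```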
+
+ ### Run Tests
+
+ ```bash
+ # Start your service first
+ make serve
+ ```
+
+ In another terminal:
+
+ ```bash
+ # Quick test (5 samples per dataset)
+ python scripts/test_datasets.py --quick
+
+ # Full validation
+ python scripts/test_datasets.py
+
+ # Test specific category
+ python scripts/test_datasets.py --category edge_case
+ ```
+
+ ### Dataset Categories (25 datasets each)
+
+ **1. Standard Tests** (`standard_test_*.parquet`)
+ - Normal images: random patterns, shapes, gradients
+ - Common sizes: 224x224, 256x256, 299x299, 384x384
+ - Formats: JPEG, PNG
+ - Purpose: Baseline validation
+
+ **2. Edge Cases** (`edge_case_*.parquet`)
+ - Tiny images (32x32, 1x1)
+ - Huge images (2048x2048)
+ - Extreme aspect ratios (1000x50)
+ - Corrupted data, malformed requests
+ - Purpose: Test error handling
+
+ **3. Performance Benchmarks** (`performance_test_*.parquet`)
+ - Batch sizes: 1, 5, 10, 25, 50, 100 images
+ - Latency and throughput tracking
+ - Purpose: Performance profiling
+
+ **4. Model Comparisons** (`model_comparison_*.parquet`)
+ - Same inputs across different architectures
+ - Models: ResNet-18/50, ViT, ConvNext, Swin
+ - Purpose: Cross-model benchmarking
+
+ ### Test Output
+
+ ```
+ DATASET TESTING SUMMARY
+ ============================================================
+ Datasets tested: 100
+ Successful datasets: 95
+ Failed datasets: 5
+ Total samples: 1,247
+ Overall success rate: 87.3%
+ Test duration: 45.2s
+
+ Performance:
+   Avg latency: 123.4ms
+   Median latency: 98.7ms
+   p95 latency: 342.1ms
+   Max latency: 2,341.0ms
+   Requests/sec: 27.6
+
+ Category breakdown:
+   standard: 25 datasets, 94.2% avg success
+   edge_case: 25 datasets, 76.8% avg success
+   performance: 25 datasets, 91.1% avg success
+   model_comparison: 25 datasets, 89.3% avg success
+ ```
+
+ ## Common Issues
+
+ **Port 8000 already in use:**
+ ```bash
+ # Find what's using it
+ lsof -i :8000
+
+ # Or just use a different port
+ uvicorn main:app --port 8080
+ ```
+
+ **Model not loading:**
+ - Check the path: models should be in `models/<org>/<model-name>/`
+ - If you're trying to run the example ResNet-based model, make sure you ran `make download` to fetch the model weights.
+ - Check logs for the exact error
+
+ **Slow inference:**
+ - Inference runs on CPU by default
+ - For GPU: install CUDA PyTorch and modify the service to use the GPU device
+ - Consider using smaller models or quantization
+
+ ## License
+
+ Apache 2.0
app/__init__.py ADDED
@@ -0,0 +1,7 @@
+ """
+ ML Inference Service
+
+ A FastAPI-based web service for machine learning model inference.
+ """
+
+ __version__ = "0.1.0"
app/api/__init__.py ADDED
@@ -0,0 +1 @@
+ """API layer for the ML inference service."""
app/api/controllers.py ADDED
@@ -0,0 +1,36 @@
+ """API controllers for request handling and validation."""
+
+ import asyncio
+ from fastapi import HTTPException
+
+ from app.core.logging import logger
+ from app.services.base import InferenceService
+ from app.api.models import ImageRequest, PredictionResponse
+
+
+ class PredictionController:
+     """Controller for prediction endpoints."""
+
+     @staticmethod
+     async def predict(
+         request: ImageRequest,
+         service: InferenceService
+     ) -> PredictionResponse:
+         """Run inference using the configured service."""
+         try:
+             if not service or not service.is_loaded:
+                 raise HTTPException(503, "Service not available")
+
+             if not request.image.mediaType.startswith('image/'):
+                 raise HTTPException(400, f"Invalid media type: {request.image.mediaType}")
+
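+             # Run the potentially CPU-bound model call in a worker thread via
+             # asyncio.to_thread so the event loop stays free for other requests.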
+             return await asyncio.to_thread(service.predict, request)
+
+         except HTTPException:
+             raise
+         except ValueError as e:
+             logger.error(f"Invalid input: {e}")
+             raise HTTPException(400, str(e))
+         except Exception as e:
+             logger.error(f"Prediction failed: {e}")
+             raise HTTPException(500, "Internal server error")
app/api/models.py ADDED
@@ -0,0 +1,109 @@
+ """
+ Pydantic models for request/response validation.
+ """
+
+ from __future__ import annotations
+
+ import base64
+ import enum
+ import io
+ import typing
+ from typing import Literal, Optional
+
+ import numpy as np
+ import pydantic
+ from PIL import Image
+
+ if typing.TYPE_CHECKING:
+     from numpy.typing import NDArray
+
+
+ class ImageData(pydantic.BaseModel):
+     """Image data model for base64 encoded images."""
+     mediaType: str = pydantic.Field(
+         description="The IETF Media Type (MIME type) of the data"
+     )
+     data: str = pydantic.Field(
+         description="A base64 string encoding of the data.",
+         # Canonical base64 encoding
+         # https://stackoverflow.com/a/64467300/3709935
+         pattern=r"^(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/][AQgw]==|[A-Za-z0-9+/]{2}[AEIMQUYcgkosw048]=)?$",
+     )
+
+
+ class BinaryMask(pydantic.BaseModel):
+     """A bit mask indicating which pixels are manipulated / synthesized. A
+     pixel value of ``0`` means "no detection", and a value of ``1`` means
+     "detection".
+
+     The mask data must be encoded in PNG format, so that typical mask data is
+     compressed effectively. The PNG encoding **should** use "bilevel" mode for
+     maximum compactness. You can use the ``BinaryMask.from_numpy()``
+     function to convert a 0-1 numpy array to a BinaryMask.
+     """
+     mediaType: Literal["image/png"] = pydantic.Field(
+         description="The IETF Media Type (MIME type) of the data."
+     )
+     data: str = pydantic.Field(
+         description="A base64 string encoding of the data.",
+         # Canonical base64 encoding
+         # https://stackoverflow.com/a/64467300/3709935
+         pattern=r"^(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/][AQgw]==|[A-Za-z0-9+/]{2}[AEIMQUYcgkosw048]=)?$",
+     )
+
+     @staticmethod
+     def from_numpy(mask: NDArray[np.uint8]) -> BinaryMask:
+         """Convert a 0-1 numpy array to a BinaryMask.
+
+         The numpy data must be in row-major order. That means the first
+         dimension corresponds to **height** and the second dimension corresponds
+         to **width**.
+         """
+         # Convert to "L" (grayscale) then "1" (bilevel) for compact binary representation
+         mask_img = Image.fromarray(mask * 255, mode="L").convert("1", dither=None)
+         mask_img_buffer = io.BytesIO()
+         mask_img.save(mask_img_buffer, format="png")
+         mask_data = base64.b64encode(mask_img_buffer.getbuffer()).decode("utf-8")
+         return BinaryMask(mediaType="image/png", data=mask_data)
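+
+     # Example usage (illustrative, not part of the API):
+     #   mask = np.zeros((height, width), dtype=np.uint8)   # row-major: (H, W)
+     #   mask[10:20, 30:40] = 1                              # 1 = detected pixels
+     #   encoded = BinaryMask.from_numpy(mask)               # bilevel PNG, base64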
+
+
+ class ImageRequest(pydantic.BaseModel):
+     """Request model for image classification."""
+     image: ImageData
+
+
+ class Labels(enum.IntEnum):
+     Natural = 0
+     FullySynthesized = 1
+     LocallyEdited = 2
+     LocallySynthesized = 3
+
+
+ class PredictionResponse(pydantic.BaseModel):
+     """Response model for synthetic image classification results.
+
+     Detector models will be scored primarily on their ability to classify the
+     entire image into 1 of the 4 label categories::
+
+         0: (Natural) The image is natural / unaltered.
+         1: (FullySynthesized) The entire image was synthesized by e.g., a
+            generative image model.
+         2: (LocallyEdited) The image is a natural image where a portion has
+            been edited using traditional photo editing techniques such as
+            splicing.
+         3: (LocallySynthesized) The image is a natural image where a portion
+            has been replaced by synthesized content.
+     """
+
+     logprobs: list[float] = pydantic.Field(
+         description="The log-probabilities for each of the 4 possible labels.",
+         min_length=4,
+         max_length=4,
+     )
+
+     localizationMask: Optional[BinaryMask] = pydantic.Field(
+         # default=None so the field really is optional, as the description says
+         default=None,
+         description="A bit mask localizing predicted edits. Models that are"
+         " not capable of localization may omit this field. It may also be"
+         " omitted if the predicted label is ``0`` or ``1``, in which case the"
+         " mask will be assumed to be all 0's or all 1's, as appropriate."
+     )
app/api/routes/__init__.py ADDED
@@ -0,0 +1 @@
+ """API route definitions."""
app/api/routes/prediction.py ADDED
@@ -0,0 +1,23 @@
+ """Prediction API routes."""
+
+ from fastapi import APIRouter, Depends
+
+ from app.api.controllers import PredictionController
+ from app.api.models import ImageRequest, PredictionResponse
+ from app.core.dependencies import get_inference_service
+ from app.services.base import InferenceService
+
+ router = APIRouter()
+
+
+ @router.post("/predict", response_model=PredictionResponse)
+ async def predict(
+     request: ImageRequest,
+     service: InferenceService = Depends(get_inference_service)
+ ):
+     """
+     Run inference on a base64-encoded image.
+
+     Returns log-probabilities for the 4 labels and an optional localization mask.
+     """
+     return await PredictionController.predict(request, service)
app/core/__init__.py ADDED
@@ -0,0 +1 @@
+ """Core utilities and configurations."""
app/core/app.py ADDED
@@ -0,0 +1,80 @@
+ """FastAPI application factory and core infrastructure."""
+
+ import asyncio
+ import warnings
+ from contextlib import asynccontextmanager
+ from typing import AsyncGenerator
+
+ from fastapi import FastAPI
+ from pydantic import Field
+ from pydantic_settings import BaseSettings
+
+ from app.core.logging import logger
+ from app.core.dependencies import set_inference_service
+ from app.services.inference import ResNetInferenceService
+ from app.services.inference_clip import CLIPInferenceService
+ from app.services.inference_dino import DINOInferenceService
+ from app.api.routes import prediction
+
+
+ class Settings(BaseSettings):
+     """Application settings. Override via environment variables or .env file."""
+
+     app_name: str = Field(default="ML Inference Service")
+     app_version: str = Field(default="0.1.0")
+     debug: bool = Field(default=False)
+     host: str = Field(default="0.0.0.0")
+     port: int = Field(default=8000)
+
+     class Config:
+         env_file = ".env"
+
+
+ settings = Settings()
+
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
+     """Application lifecycle: startup/shutdown."""
+     logger.info("Starting ML Inference Service...")
+
+     try:
+         with warnings.catch_warnings():
+             warnings.filterwarnings("ignore", category=FutureWarning)
+
+             # Replace ResNetInferenceService with your own implementation
+             # service = ResNetInferenceService(model_name="microsoft/resnet-18")
+             # service = CLIPInferenceService(model_name="clip")
+             service = DINOInferenceService(model_name="dino")
+             await asyncio.to_thread(service.load_model)
+             set_inference_service(service)
+
+         logger.info("Startup completed successfully")
+
+     except Exception as e:
+         logger.error(f"Startup failed: {e}")
+         raise
+
+     yield
+
+     logger.info("Shutting down...")
+
+
+ def create_app() -> FastAPI:
+     """Create and configure FastAPI application."""
+     app = FastAPI(
+         title=settings.app_name,
+         description="ML inference service for image classification",
+         version=settings.app_version,
+         debug=settings.debug,
+         lifespan=lifespan
+     )
+
+     app.include_router(prediction.router)
+
+     return app
app/core/dependencies.py ADDED
@@ -0,0 +1,17 @@
+ """Dependency injection for services."""
+
+ from typing import Optional
+ from app.services.base import InferenceService
+
+ _inference_service: Optional[InferenceService] = None
+
+
+ def get_inference_service() -> Optional[InferenceService]:
+     """Get inference service for dependency injection."""
+     return _inference_service
+
+
+ def set_inference_service(service: InferenceService) -> None:
+     """Set inference service. Called internally during startup."""
+     global _inference_service
+     _inference_service = service
app/core/logging.py ADDED
@@ -0,0 +1,26 @@
+ """Logging configuration."""
+
+ import logging
+ import sys
+
+
+ def setup_logging(logger_name: str = "ML Inference Service") -> logging.Logger:
+     """Set up and configure the logger."""
+     logger = logging.getLogger(logger_name)
+
+     if logger.handlers:
+         return logger
+
+     logger.setLevel(logging.INFO)
+     handler = logging.StreamHandler(sys.stdout)
+     handler.setLevel(logging.INFO)
+     formatter = logging.Formatter(
+         "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+     )
+     handler.setFormatter(formatter)
+     logger.addHandler(handler)
+
+     return logger
+
+
+ logger = setup_logging()
app/services/__init__.py ADDED
@@ -0,0 +1 @@
+ """Business logic services."""
app/services/base.py ADDED
@@ -0,0 +1,30 @@
+ """Abstract base class for ML inference services."""
+
+ from abc import ABC, abstractmethod
+ from typing import Generic, TypeVar
+ from pydantic import BaseModel
+
+ TRequest = TypeVar('TRequest', bound=BaseModel)
+ TResponse = TypeVar('TResponse', bound=BaseModel)
+
+
+ class InferenceService(ABC, Generic[TRequest, TResponse]):
+     """
+     Base class for inference services. Subclass this to integrate your model.
+     """
+
+     @abstractmethod
+     def load_model(self) -> None:
+         """Load model weights and processors. Called once at startup."""
+         pass
+
+     @abstractmethod
+     def predict(self, request: TRequest) -> TResponse:
+         """Run inference and return typed response."""
+         pass
+
+     @property
+     @abstractmethod
+     def is_loaded(self) -> bool:
+         """Check if model is loaded and ready."""
+         pass
app/services/inference.py ADDED
@@ -0,0 +1,92 @@
+ """ResNet inference service implementation."""
+
+ import base64
+ import os
+ from io import BytesIO
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ from transformers import AutoImageProcessor, ResNetForImageClassification  # type: ignore[import-untyped]
+
+ from app.core.logging import logger
+ from app.services.base import InferenceService
+ from app.api.models import BinaryMask, ImageRequest, Labels, PredictionResponse
+
+
+ class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
+     """ResNet-18 inference service for image classification."""
+
+     def __init__(self, model_name: str = "microsoft/resnet-18"):
+         self.model_name = model_name
+         self.model = None
+         self.processor = None
+         self._is_loaded = False
+         self.model_path = os.path.join("models", model_name)
+         logger.info(f"Initializing ResNet service: {self.model_path}")
+
+     def load_model(self) -> None:
+         if self._is_loaded:
+             return
+
+         if not os.path.exists(self.model_path):
+             raise FileNotFoundError(f"Model not found: {self.model_path}")
+
+         config_path = os.path.join(self.model_path, "config.json")
+         if not os.path.exists(config_path):
+             raise FileNotFoundError(f"Config not found: {config_path}")
+
+         logger.info(f"Loading model from {self.model_path}")
+
+         import warnings
+         with warnings.catch_warnings():
+             warnings.filterwarnings("ignore", category=FutureWarning)
+             self.processor = AutoImageProcessor.from_pretrained(
+                 self.model_path, local_files_only=True
+             )
+             self.model = ResNetForImageClassification.from_pretrained(
+                 self.model_path, local_files_only=True
+             )
+         assert self.model is not None
+
+         self._is_loaded = True
+         logger.info(f"Model loaded: {len(self.model.config.id2label)} classes")  # pyright: ignore
+
+     def predict(self, request: ImageRequest) -> PredictionResponse:
+         if not self.is_loaded:
+             raise RuntimeError("model is not loaded")
+         assert self.processor is not None
+         assert self.model is not None
+
+         image_data = base64.b64decode(request.image.data)
+         image = Image.open(BytesIO(image_data))
+
+         if image.mode != 'RGB':
+             image = image.convert('RGB')
+
+         inputs = self.processor(image, return_tensors="pt")
+
+         with torch.no_grad():
+             logits = self.model(**inputs).logits.squeeze()  # pyright: ignore
+
+         # Convert the ImageNet output vector of dimension 1000 to the expected
+         # output format. This is for demonstration purposes
+         # and obviously will not perform well on the actual task.
+         logprobs = torch.nn.functional.log_softmax(logits[:len(Labels)], dim=-1).tolist()
+
+         # Dummy localization mask: a rectangle approximately in the middle
+         x = image.width // 3
+         y = image.height // 3
+         # Row-major order
+         mask = np.zeros((image.height, image.width), dtype=np.uint8)
+         mask[y:(2*y), x:(2*x)] = 1
+         mask_obj = BinaryMask.from_numpy(mask)
+
+         return PredictionResponse(
+             logprobs=logprobs,
+             localizationMask=mask_obj,
+         )
+
+     @property
+     def is_loaded(self) -> bool:
+         return self._is_loaded
app/services/inference_clip.py ADDED
@@ -0,0 +1,171 @@
+ """CLIP inference service implementation."""
+
+ import base64
+ import os
+ from io import BytesIO
+
+ import numpy as np
+ import torch
+ import torchvision.transforms as transforms
+ from PIL import Image
+ from transformers import AutoImageProcessor, ResNetForImageClassification  # type: ignore[import-untyped]
+
+ from app.core.logging import logger
+ from app.services.base import InferenceService
+ from app.api.models import BinaryMask, ImageRequest, Labels, PredictionResponse
+
+ from app.services.models.alignment_pretrained.unet import UNetImageDecoder
+ from app.services.models.CLIPSvD import CLIPSvD
+ from app.services.models.alignment_pretrained.model_with_bce_images_blip import MMModerator
+
+
+ class CLIPDINOProcessor:
+     """Image processor for CLIP and DINO models with 224x224 resizing and normalization."""
+
+     def __init__(self, image_size: int = 224):
+         self.image_size = image_size
+         # Standard ImageNet normalization used by both CLIP and DINO
+         self.transform = transforms.Compose([
+             transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC),
+             transforms.ToTensor(),
+             transforms.Normalize(
+                 mean=[0.485, 0.456, 0.406],  # ImageNet mean
+                 std=[0.229, 0.224, 0.225]    # ImageNet std
+             )
+         ])
+
+     def __call__(self, image: Image.Image, return_tensors: str = "pt"):
+         """
+         Process an image for CLIP/DINO input.
+
+         Args:
+             image: PIL Image object
+             return_tensors: Format of returned tensors (default: "pt" for PyTorch)
+
+         Returns:
+             Dictionary with 'pixel_values' key containing the processed tensor
+         """
+         if not isinstance(image, Image.Image):
+             raise ValueError("Input must be a PIL Image")
+
+         # Apply transforms
+         pixel_values = self.transform(image)
+
+         # Add batch dimension if needed
+         if pixel_values.dim() == 3:
+             pixel_values = pixel_values.unsqueeze(0)
+
+         return {"pixel_values": pixel_values}
+
+
+ def create_vision_encoder():
+     model = CLIPSvD()
+     return model
+
+
+ class CLIPInferenceService(InferenceService[ImageRequest, PredictionResponse]):
+     """CLIP-based inference service for synthetic image classification."""
+
+     def __init__(self, model_name: str = "microsoft/resnet-18"):
+         self.model_name = model_name
+         self.processor = CLIPDINOProcessor(image_size=224)
+         self._is_loaded = False
+         self.model_path = os.path.join("models", model_name)
+
+         pretraining = False
+         num_classes = 4
+         self.vision_encoder = create_vision_encoder()
+         # self.vision_encoder.to(device=device, dtype=torch.float32)
+         self.unet_decoder = UNetImageDecoder(
+             num_patches=256,    # 16 x 16 patch grid
+             token_dim=1024,     # embedding dim
+             out_channels=3,     # 1 for a mask, 3 for RGB
+             base_channels=256,
+             img_size=256,
+             grid_hw=(16, 16)    # explicitly set to match the patch grid
+         )
+         self.model = MMModerator(
+             pretraining=pretraining,
+             vision_encoder=self.vision_encoder,
+             unet_decoder=self.unet_decoder,
+             num_classes=num_classes,
+         )
+         logger.info(f"Initializing CLIP service: {self.model_path}")
+
+     def load_model(self) -> None:
+         if self._is_loaded:
+             return
+
+         if not os.path.exists(self.model_path):
+             raise FileNotFoundError(f"Model not found: {self.model_path}")
+
+         logger.info(f"Loading model from {self.model_path}")
+         checkpoint_path = os.path.join(self.model_path, "model_state.pt")
+         checkpoint_path_encoder = os.path.join(self.model_path, "model_state_encoder.pt")
+         checkpoint_path_decoder = os.path.join(self.model_path, "model_state_decoder.pt")
+         # map_location="cpu" so checkpoints saved on GPU also load in the CPU container
+         raw = torch.load(checkpoint_path, map_location="cpu")
+         raw_encoder = torch.load(checkpoint_path_encoder, map_location="cpu")
+         raw_decoder = torch.load(checkpoint_path_decoder, map_location="cpu")
+
+         def strip_module_prefix(state_dict):
+             # Checkpoints saved with DataParallel/DDP prefix keys with "module."
+             return {k.replace("module.", ""): v for k, v in state_dict.items()}
+
+         self.model.load_state_dict(strip_module_prefix(raw.get("model_state_dict", raw)))
+         self.vision_encoder.load_state_dict(
+             strip_module_prefix(raw_encoder.get("model_state_dict", raw_encoder))
+         )
+         self.unet_decoder.load_state_dict(
+             strip_module_prefix(raw_decoder.get("model_state_dict", raw_decoder))
+         )  # strict=True by default
+
+         self._is_loaded = True
+
+     def predict(self, request: ImageRequest) -> PredictionResponse:
+         if not self.is_loaded:
+             raise RuntimeError("model is not loaded")
+         assert self.processor is not None
+         assert self.model is not None
+
+         image_data = base64.b64decode(request.image.data)
+         image = Image.open(BytesIO(image_data))
+         if image.mode != 'RGB':
+             image = image.convert('RGB')
+
+         processed = self.processor(image, return_tensors="pt")
+         pixel_values = processed["pixel_values"]
+
+         self.model.eval()
+         with torch.no_grad():
+             logits_cls, logits, losses, labels_expanded, data_labels_expanded, image_recon, overlay = self.model(images=pixel_values)
+
+         logprobs = torch.nn.functional.log_softmax(logits[:len(Labels)], dim=-1).tolist()[0]
+
+         # Localization is disabled for now; a mask could be built from
+         # image_recon with BinaryMask.from_numpy(...) once it is calibrated.
+         logger.debug(f"logprobs: {logprobs}")
+         return PredictionResponse(
+             logprobs=logprobs,
+             localizationMask=None,
+         )
+
+     @property
+     def is_loaded(self) -> bool:
+         return self._is_loaded
app/services/inference_dino.py ADDED
@@ -0,0 +1,177 @@
+ """DINO inference service implementation."""
+
+ import base64
+ import os
+ from io import BytesIO
+
+ import numpy as np
+ import torch
+ import torchvision.transforms as transforms
+ from PIL import Image
+ from transformers import AutoImageProcessor, ResNetForImageClassification  # type: ignore[import-untyped]
+
+ from app.core.logging import logger
+ from app.services.base import InferenceService
+ from app.api.models import BinaryMask, ImageRequest, Labels, PredictionResponse
+
+ from app.services.models.alignment_pretrained.unet import UNetImageDecoder
+ from app.services.models.DinoLORA import DINOEncoderLoRA
+ from app.services.models.alignment_pretrained.model_with_bce_images_dino import MMModerator
+
+
+ class DINODINOProcessor:
+     """Image processor for DINO models with square resizing and ImageNet normalization."""
+
+     def __init__(self, image_size: int = 512):
+         self.image_size = image_size
+         # Standard ImageNet normalization
+         self.transform = transforms.Compose([
+             transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC),
+             transforms.ToTensor(),
+             transforms.Normalize(
+                 mean=[0.485, 0.456, 0.406],  # ImageNet mean
+                 std=[0.229, 0.224, 0.225]    # ImageNet std
+             )
+         ])
+
+     def __call__(self, image: Image.Image, return_tensors: str = "pt"):
+         """
+         Process an image for DINO input.
+
+         Args:
+             image: PIL Image object
+             return_tensors: Format of returned tensors (default: "pt" for PyTorch)
+
+         Returns:
+             Dictionary with 'pixel_values' key containing the processed tensor
+         """
+         if not isinstance(image, Image.Image):
+             raise ValueError("Input must be a PIL Image")
+
+         # Apply transforms
+         pixel_values = self.transform(image)
+
+         # Add batch dimension if needed
+         if pixel_values.dim() == 3:
+             pixel_values = pixel_values.unsqueeze(0)
+
+         return {"pixel_values": pixel_values}
+
+
+ def create_vision_encoder():
+     REPO_DIR = "/media/NAS/USERS/shahid/MultimodalAudioVisualModerator/dinov3"
+     encoder = torch.hub.load(REPO_DIR, 'dinov3_vitb16', source='local', weights="https://dinov3.llamameta.net/dinov3_vitb16/dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth?Policy=eyJTdGF0ZW1lbnQiOlt7InVuaXF1ZV9oYXNoIjoib3VpbXR2cHlhZXE5c2JwajNucnN3aWF2IiwiUmVzb3VyY2UiOiJodHRwczpcL1wvZGlub3YzLmxsYW1hbWV0YS5uZXRcLyoiLCJDb25kaXRpb24iOnsiRGF0ZUxlc3NUaGFuIjp7IkFXUzpFcG9jaFRpbWUiOjE3NTk5MDI0NDF9fX1dfQ__&Signature=GIYJR4%7ESJVx0gkcm7lgAvDljIfpR30MXgWb2VpCqbDeVpnwjn97k%7EOcPeGF-lkR0q1Sn3Iw5Y3iYWqspcpPoDJ4FXUmMKhWtd-m00HO73Aknq2kyrKVMBpzwQB-k-2zZe7okJfXTj46EWbzu9mNcxt%7ErDPe7phQpRJi8Dleida1BJ823oXFx8d7oRSa4NDSzT2TNXqNNZ8ux7N0aDfdT9dupEeEr4AP06LhYB2I7kF%7Ef4bvKQsKnlPMVDAADyYG9nQ7HqAW41LaWZtR-BrDGm%7ESNu-6L44cUVnk3qEPVRMQB4GW7ixRGGhtr37F6HVz%7EKilrCpivFD6ej4reNUWaGQ__&Key-Pair-Id=K15QRJLYKIFSLZ&Download-Request-ID=829796286371500")
+     model = DINOEncoderLoRA(encoder, r=16, emb_dim=1024, use_lora=True)
+     return model
+
+
+ class DINOInferenceService(InferenceService[ImageRequest, PredictionResponse]):
+     """DINOv3-based inference service for synthetic image classification."""
+
+     def __init__(self, model_name: str = "microsoft/resnet-18"):
+         self.model_name = model_name
+         self.processor = DINODINOProcessor(image_size=512)
+         self._is_loaded = False
+         self.model_path = os.path.join("models", model_name)
+
+         pretraining = False
+         num_classes = 4
+         self.vision_encoder = create_vision_encoder()
+         # self.vision_encoder.to(device=device, dtype=torch.float32)
+
+         self.unet_decoder = UNetImageDecoder(
+             num_patches=1024,    # MUST match N=1024
+             token_dim=768,       # because tokens are [B, 1024, 768]
+             out_channels=1,      # single-channel mask output
+             base_channels=128,   # recommended for 512px
+             img_size=512,        # image resolution
+             grid_hw=(32, 32)     # MUST match N=1024
+         )
+         self.model = MMModerator(
+             pretraining=pretraining,
+             vision_encoder=self.vision_encoder,
+             unet_decoder=self.unet_decoder,
+             num_classes=num_classes,
+         )
+         self.model.eval()
+         logger.info(f"Initializing DINO service: {self.model_path}")
+
+     def load_model(self) -> None:
+         if self._is_loaded:
+             return
+
+         if not os.path.exists(self.model_path):
+             raise FileNotFoundError(f"Model not found: {self.model_path}")
+
+         logger.info(f"Loading model from {self.model_path}")
+         checkpoint_path = os.path.join(self.model_path, "model_state.pt")
+         checkpoint_path_encoder = os.path.join(self.model_path, "model_state_encoder.pt")
+         checkpoint_path_decoder = os.path.join(self.model_path, "model_state_decoder.pt")
+         # map_location="cpu" so checkpoints saved on GPU also load in the CPU container
+         raw = torch.load(checkpoint_path, map_location="cpu")
+         raw_encoder = torch.load(checkpoint_path_encoder, map_location="cpu")
+         raw_decoder = torch.load(checkpoint_path_decoder, map_location="cpu")
+
+         def strip_module_prefix(state_dict):
+             # Checkpoints saved with DataParallel/DDP prefix keys with "module."
+             return {k.replace("module.", ""): v for k, v in state_dict.items()}
+
+         self.model.load_state_dict(strip_module_prefix(raw.get("model_state_dict", raw)))
+         self.vision_encoder.load_state_dict(
+             strip_module_prefix(raw_encoder.get("model_state_dict", raw_encoder))
+         )
+         self.unet_decoder.load_state_dict(
+             strip_module_prefix(raw_decoder.get("model_state_dict", raw_decoder))
+         )  # strict=True by default
+
+         self._is_loaded = True
+
+     def predict(self, request: ImageRequest) -> PredictionResponse:
+         if not self.is_loaded:
+             raise RuntimeError("model is not loaded")
+         assert self.processor is not None
+         assert self.model is not None
+
+         image_data = base64.b64decode(request.image.data)
+         image = Image.open(BytesIO(image_data))
+         if image.mode != 'RGB':
+             image = image.convert('RGB')
+
+         processed = self.processor(image, return_tensors="pt")
+         pixel_values = processed["pixel_values"]
+
+         self.model.eval()
+         with torch.no_grad():
+             logits, logits_multi_cls, losses, label, data_label, image_recon, overlay, shuffled_images, gt_masks = self.model(images=pixel_values)
+
+         logprobs = torch.nn.functional.log_softmax(logits_multi_cls[:len(Labels)], dim=-1).tolist()[0]
+
+         # Localization is disabled for now; a mask could be built from
+         # image_recon with BinaryMask.from_numpy(...) once it is calibrated.
+         logger.debug(f"logprobs: {logprobs}")
+         return PredictionResponse(
+             logprobs=logprobs,
+             localizationMask=None,
+         )
+
+     @property
+     def is_loaded(self) -> bool:
+         return self._is_loaded
cat.json ADDED
The diff for this file is too large to render. See raw diff
 
challenge-cli.py ADDED
@@ -0,0 +1,310 @@
+ # SPDX-FileCopyrightText: 2025 UL Research Institutes
+ # SPDX-License-Identifier: Apache-2.0
+
+ import functools
+ import time
+ from datetime import datetime, timedelta, timezone
+ from pathlib import Path
+
+ import click
+ import httpx
+
+ from dyff.client import Client, errors
+ from dyff.schema.platform import *
+ from dyff.schema.requests import *
+
+ from app.api.models import PredictionResponse
+
+ # ----------------------------------------------------------------------------
+
+
+ def _wait_for_status(
+     get_entity_fn, target_status: str | list[str], *, timeout: timedelta
+ ) -> str:
+     if isinstance(target_status, str):
+         target_status = [target_status]
+     then = datetime.now(timezone.utc)
+     while True:
+         try:
+             status = get_entity_fn().status
+             if status in target_status:
+                 return status
+         except errors.HTTPError as ex:
+             if ex.status != 404:
+                 raise
+         except httpx.HTTPStatusError as ex:
+             if ex.response.status_code != 404:
+                 raise
+         if (datetime.now(timezone.utc) - then) >= timeout:
+             break
+         time.sleep(5)
+     raise AssertionError("timeout")
+
+
+ def _common_options(f):
+     @click.option(
+         "--account",
+         type=str,
+         required=True,
+         help="Your account ID",
+         metavar="ID",
+     )
+     @functools.wraps(f)
+     def wrapper(*args, **kwargs):
+         return f(*args, **kwargs)
+     return wrapper
+
+
+ @click.group()
+ def cli():
+     pass
+
+
+ @cli.command()
+ @_common_options
+ @click.option(
+     "--name",
+     type=str,
+     required=True,
+     help="The name of your detector model. For display and querying purposes only.",
+ )
+ @click.option(
+     "--image",
+     type=str,
+     default=None,
+     help="The Docker image to upload (e.g., 'some/image:latest')."
+     " Must exist in your local Docker daemon."
+     " Required if --artifact is not specified.",
+ )
+ @click.option(
+     "--endpoint",
+     type=str,
+     default="predict",
+     help="The endpoint to call on your service to make a prediction.",
+ )
+ @click.option(
+     "--volume",
+     type=click.Path(exists=True, file_okay=False, readable=True, resolve_path=True, path_type=Path),
+     default=None,
+     help="A local directory path containing files to upload and mount in the running Docker container."
+     " You should use this if your submission includes large files like neural network weights.",
+ )
+ @click.option(
+     "--volume-mount",
+     type=click.Path(exists=False, path_type=Path),
+     default=None,
+     help="The path to mount your uploaded directory in the running Docker container."
+     " Must be an absolute path."
+     " Required if --volume is specified.",
+ )
+ @click.option(
+     "--artifact",
+     "artifact_id",
+     type=str,
+     default=None,
+     help="The ID of the Artifact (i.e., Docker image) to use in the submission, if it already exists."
+     " You can pass the artifact.id from a previous invocation.",
+     metavar="ID",
+ )
+ @click.option(
+     "--model",
+     "model_id",
+     type=str,
+     default=None,
+     help="The ID of the Model (i.e., neural network weights) to use in the submission, if it already exists."
+     " You can pass the model.id from a previous invocation.",
+     metavar="ID",
+ )
+ @click.option(
+     "--gpu",
+     is_flag=True,
+     default=False,
+     help="Request a GPU (NVIDIA L4) for the inference service.",
+ )
+ def upload_submission(
+     account: str,
+     name: str,
+     image: str | None,
+     endpoint: str,
+     volume: Path | None,
+     volume_mount: Path | None,
+     artifact_id: str | None,
+     model_id: str | None,
+     gpu: bool,
+ ) -> None:
+     dyffapi = Client()
+
+     # Upload the image
+     if artifact_id is None:
+         # Create an Artifact resource
+         click.echo(f"creating Artifact ... {account}")
+         artifact = dyffapi.artifacts.create(ArtifactCreateRequest(account=account))
+         click.echo(f'artifact.id: "{artifact.id}"')
+         _wait_for_status(
+             lambda: dyffapi.artifacts.get(artifact.id),
+             "WaitingForUpload",
+             timeout=timedelta(seconds=30),
+         )
+
+         # Push the image from the local Docker daemon
+         click.echo("pushing Artifact ...")
+         dyffapi.artifacts.push(artifact, source=f"docker-daemon:{image}")
+         time.sleep(5)
+
+         # Indicate that we're done pushing
+         dyffapi.artifacts.finalize(artifact.id)
+         _wait_for_status(
+             lambda: dyffapi.artifacts.get(artifact.id),
+             "Ready",
+             timeout=timedelta(seconds=30),
+         )
+
+         click.echo("... done")
+     else:
+         artifact = dyffapi.artifacts.get(artifact_id)
+         assert artifact is not None
+
+     model: Model | None = None
+     if model_id is None:
+         if volume is not None:
+             if volume_mount is None:
+                 raise click.UsageError("--volume-mount is required when --volume is used")
+
+             click.echo("creating Model from local directory ...")
+             model = dyffapi.models.create_from_volume(
+                 volume, name="model_volume", account=account, resources=ModelResources()
+             )
+             click.echo(f'model.id: "{model.id}"')
+             _wait_for_status(
+                 lambda: dyffapi.models.get(model.id),
+                 "WaitingForUpload",
+                 timeout=timedelta(seconds=30),
+             )
+
+             click.echo("uploading Model ...")
+             dyffapi.models.upload_volume(model, volume)
+             _wait_for_status(
+                 lambda: dyffapi.models.get(model.id),
+                 "Ready",
+                 timeout=timedelta(seconds=30),
+             )
+
+             click.echo("... done")
+         else:
+             model = None
+     else:
+         model = dyffapi.models.get(model_id)
+         assert model is not None
+
+     # Create a runnable InferenceService
+     if volume_mount is not None:
+         if model is None:
+             raise click.UsageError("--volume-mount requires --volume or --model")
+         if not volume_mount.is_absolute():
+             raise click.UsageError("--volume-mount must be an absolute path")
+         volumeMounts = [
+             VolumeMount(
+                 kind=VolumeMountKind.data,
+                 name="model",
+                 mountPath=volume_mount,
+                 data=VolumeMountData(
+                     source=EntityIdentifier.of(model),
+                 ),
+             ),
+         ]
+     else:
+         volumeMounts = None
+
+     accelerator: Accelerator | None = None
+     if gpu:
+         accelerator = Accelerator(
+             kind="GPU",
+             gpu=AcceleratorGPU(
+                 hardwareTypes=["nvidia.com/gpu-l4"],
+                 count=1,
+             ),
+         )
+
+     # Don't change this
+     service_request = InferenceServiceCreateRequest(
+         account=account,
+         name=name,
+         model=None,
+         runner=InferenceServiceRunner(
+             kind=InferenceServiceRunnerKind.CONTAINER,
+             imageRef=EntityIdentifier.of(artifact),
+             resources=ModelResources(),
+             volumeMounts=volumeMounts,
+             accelerator=accelerator,
+         ),
+         interface=InferenceInterface(
+             endpoint=endpoint,
+             outputSchema=DataSchema.make_output_schema(PredictionResponse),
+         ),
+     )
+     click.echo("creating InferenceService ...")
+     service = dyffapi.inferenceservices.create(service_request)
+     click.echo(f'service.id: "{service.id}"')
+     click.echo("... done")
+
+
+ @cli.command()
+ @_common_options
+ @click.option(
+     "--task",
+     "task_id",
+     type=str,
+     required=True,
+     help="The Task ID to submit to.",
+     metavar="ID",
+ )
+ @click.option(
+     "--team",
+     "team_id",
+     type=str,
+     required=True,
+     help="The Team ID making the submission.",
+     metavar="ID",
+ )
+ @click.option(
+     "--service",
+     "service_id",
+     type=str,
+     required=True,
+     help="The InferenceService ID to submit.",
+     metavar="ID",
+ )
+ @click.option(
+     "--challenge",
+     "challenge_id",
+     type=str,
+     default="dc509a8c771b492b90c43012fde9a04f",
+     help="The Challenge ID to submit to.",
+     metavar="ID",
+ )
+ def submit(account: str, task_id: str, team_id: str, service_id: str, challenge_id: str) -> None:
+     dyffapi = Client()
+
+     challenge = dyffapi.challenges.get(challenge_id)
+     challengetask = challenge.tasks[task_id]
+
+     team = dyffapi.teams.get(team_id)
+     service = dyffapi.inferenceservices.get(service_id)
+
+     submission = dyffapi.challenges.submit(
+         challenge.id,
+         challengetask.id,
+         SubmissionCreateRequest(
+             account=account,
+             team=team.id,
+             submission=EntityIdentifier(kind="InferenceService", id=service.id),
+         ),
+     )
+     click.echo(submission.model_dump_json(indent=2))
+     click.echo(f'submission.id: "{submission.id}"')
+
+
+ if __name__ == "__main__":
+     cli(show_default=True)
delete_corrupted_files.py ADDED
@@ -0,0 +1,55 @@
+ import os
+ import cv2
+ from PIL import Image
+
+
+ def is_image_corrupted(path):
+     """
+     Returns True if the image cannot be opened by cv2 or PIL.
+     """
+     # ---- Try reading with OpenCV ----
+     img = cv2.imread(path)
+     if img is None:
+         return True  # cv2 failed to decode
+
+     # ---- Try reading with PIL ----
+     try:
+         with Image.open(path) as im:
+             im.verify()  # Validate without fully loading
+     except Exception:
+         return True  # PIL failed to open / verify
+
+     return False  # Image is valid
+
+
+ def delete_corrupted_images(directory):
+     """
+     Deletes all unreadable / corrupted images in the directory.
+     """
+     valid_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}
+
+     deleted_count = 0
+
+     for filename in os.listdir(directory):
+         path = os.path.join(directory, filename)
+
+         # Skip folders
+         if not os.path.isfile(path):
+             continue
+
+         # Check valid image extensions
+         ext = os.path.splitext(filename)[1].lower()
+         if ext not in valid_extensions:
+             continue
+
+         # Check corruption
+         if is_image_corrupted(path):
+             print(f"Deleting corrupted image: {filename}")
+             os.remove(path)
+             deleted_count += 1
+
+     print(f"\nDone. Deleted {deleted_count} corrupted images.")
+
+
+ # --------- RUN ----------
+ directory = "/media/NAS/DATASET/wacv_safe_challenge_2025_dataset/physical_fakes/PS-Battles/originals/"  # <-- change this
+ delete_corrupted_images(directory)
main.py ADDED
@@ -0,0 +1,6 @@
+ """
+ Main FastAPI application entry point.
+ """
+ from app.core.app import create_app
+
+ app = create_app()
makefile ADDED
@@ -0,0 +1,58 @@
+ VENV ?= venv
+ PYTHON ?= $(VENV)/bin/python3
+ UVICORN ?= $(VENV)/bin/uvicorn
+ DOCKER ?= docker
+ IMAGE ?= safe-challenge-2025/example-submission
+ GID ?= $(shell id -g)
+ UID ?= $(shell id -u)
+
+ .PHONY: setup
+ setup: $(VENV)/requirements.cpu.txt
+
+ .PHONY: download
+ download:
+ 	bash scripts/model_download.bash
+
+ .PHONY: serve
+ serve:
+ 	$(UVICORN) main:app
+
+ .PHONY: docker-build
+ docker-build:
+ 	docker build -t $(IMAGE) .
+
+ .PHONY: docker-run
+ docker-run:
+ 	docker run --rm -it -p 8000:8000 $(IMAGE)
+
+ .PHONY: compile
+ compile:
+ 	uv pip compile --python-version 3.12 --upgrade -o requirements.torch.cpu.txt.tmp requirements.torch.cpu.in
+ 	echo "--index-url https://download.pytorch.org/whl/cpu" > requirements.torch.cpu.txt
+ 	grep -e '^torch' requirements.torch.cpu.txt.tmp >> requirements.torch.cpu.txt
+ 	uv pip compile --python-version 3.12 --upgrade -o requirements.cpu.txt requirements.cpu.in
+
+ .PHONY: compile-gpu
+ compile-gpu:
+ 	uv pip compile --python-version 3.12 --upgrade -o requirements.torch.gpu.txt requirements.torch.gpu.in
+
+ .PHONY: docker-build-gpu
+ docker-build-gpu:
+ 	docker build -t $(IMAGE)-gpu -f Dockerfile.gpu .
+
+ requirements.cpu.txt: requirements.in requirements.torch.cpu.txt | $(VENV)
+ 	uv pip compile --python-version 3.12 --upgrade -o requirements.cpu.txt requirements.cpu.in
+
+ requirements.torch.cpu.txt: requirements.torch.cpu.in | $(VENV)
+ 	uv pip compile --python-version 3.12 --upgrade -o requirements.torch.cpu.txt.tmp requirements.torch.cpu.in
+ 	echo "--index-url https://download.pytorch.org/whl/cpu" > requirements.torch.cpu.txt
+ 	grep '^torch' requirements.torch.cpu.txt.tmp >> requirements.torch.cpu.txt
+
+ $(VENV)/requirements.cpu.txt: requirements.cpu.txt | $(VENV)
+ 	VIRTUAL_ENV=$(VENV) uv pip install -r requirements.cpu.txt
+ 	VIRTUAL_ENV=$(VENV) uv pip install -r requirements.torch.cpu.txt
+ 	cp -f requirements.cpu.txt $(VENV)/requirements.cpu.txt
+ 	cp -f requirements.torch.cpu.txt $(VENV)/requirements.torch.cpu.txt
+
+ $(VENV):
+ 	uv venv $(VENV)
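+ # Note: the torch/torchvision pins are kept in requirements.torch.cpu.txt so
+ # that the PyTorch CPU wheel index applies only to them; everything else in
+ # requirements.cpu.txt resolves from the default package index.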
mask.png ADDED
prompt.sh ADDED
@@ -0,0 +1,6 @@
+ #!/bin/bash
+
+ curl -X POST http://localhost:8000/predict \
+     -H "Content-Type: application/json" \
+     -d @"$1"
+
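+ # curl's @ prefix reads the JSON request body from the file passed as $1.
+ # Sketch of a pretty-printed call (assumes a request file and jq installed):
+ #   ./prompt.sh request.json | jq .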
requirements copy.txt ADDED
@@ -0,0 +1,26 @@
+
+ numpy==1.26.1
+ pandas==2.2.1
+ scipy==1.12.0
+
+ torch==2.4.0
+ torchvision==0.19.0
+ torchaudio==2.4.0
+
+ opencv-contrib-python==4.11.0.86
+ av==14.0.1
+ decord
+ kornia==0.7.4
+ timm
+ einops
+ beartype
+ lpips==0.1.4
+ imageio
+
+ audiomentations==0.36.0
+ librosa==0.10.2.post1
+ soundfile==0.12.1
+
+ transformers==4.49.0
+ datasets==3.5.1
+ tqdm
requirements.cpu.in ADDED
@@ -0,0 +1 @@
+ -r requirements.in
requirements.cpu.txt ADDED
@@ -0,0 +1,561 @@
+ # This file was autogenerated by uv via the following command:
+ # uv pip compile --python-version 3.12 -o requirements.cpu.txt requirements.cpu.in
+ absl-py==2.3.1
+ # via dyff-client
+ annotated-doc==0.0.4
+ # via fastapi
+ annotated-types==0.7.0
+ # via pydantic
+ anyio==4.12.0
+ # via
+ # httpx
+ # jupyter-server
+ # starlette
+ # watchfiles
+ argon2-cffi==25.1.0
+ # via jupyter-server
+ argon2-cffi-bindings==25.1.0
+ # via argon2-cffi
+ arrow==1.4.0
+ # via isoduration
+ asttokens==3.0.1
+ # via stack-data
+ async-lru==2.0.5
+ # via jupyterlab
+ attrs==25.4.0
+ # via
+ # jsonschema
+ # referencing
+ azure-core==1.36.0
+ # via dyff-client
+ babel==2.17.0
+ # via jupyterlab-server
+ beartype==0.22.8
+ # via -r requirements.in
+ beautifulsoup4==4.14.3
+ # via
+ # dyff-audit
+ # nbconvert
+ bleach==6.3.0
+ # via nbconvert
+ canonicaljson==2.0.0
+ # via dyff-schema
+ certifi==2025.11.12
+ # via
+ # httpcore
+ # httpx
+ # requests
+ cffi==2.0.0
+ # via argon2-cffi-bindings
+ charset-normalizer==3.4.4
+ # via requests
+ click==8.3.1
+ # via
+ # -r requirements.in
+ # uvicorn
+ comm==0.2.3
+ # via ipykernel
+ contourpy==1.3.3
+ # via matplotlib
+ cycler==0.12.1
+ # via matplotlib
+ debugpy==1.8.17
+ # via ipykernel
+ decorator==5.2.1
+ # via ipython
+ defusedxml==0.7.1
+ # via nbconvert
+ dnspython==2.8.0
+ # via email-validator
+ dyff-audit==0.16.7
+ # via -r requirements.in
+ dyff-client==0.23.9
+ # via
+ # -r requirements.in
+ # dyff-audit
+ dyff-schema==0.40.1
+ # via
+ # -r requirements.in
+ # dyff-audit
+ # dyff-client
+ einops==0.8.1
+ # via -r requirements.in
+ email-validator==2.3.0
+ # via dyff-schema
+ executing==2.2.1
+ # via stack-data
+ fastapi==0.124.1
+ # via -r requirements.in
+ fastjsonschema==2.21.2
+ # via nbformat
+ filelock==3.20.0
+ # via
+ # huggingface-hub
+ # torch
+ # transformers
+ fonttools==4.61.0
+ # via matplotlib
+ fqdn==1.5.1
+ # via jsonschema
+ fsspec==2025.12.0
+ # via
+ # huggingface-hub
+ # torch
+ google-i18n-address==3.1.1
+ # via dyff-schema
+ h11==0.16.0
+ # via
+ # httpcore
+ # uvicorn
+ hf-xet==1.2.0
+ # via huggingface-hub
+ httpcore==1.0.9
+ # via httpx
+ httptools==0.7.1
+ # via uvicorn
+ httpx==0.28.1
+ # via
+ # dyff-client
+ # jupyterlab
+ huggingface-hub==0.36.0
+ # via
+ # timm
+ # tokenizers
+ # transformers
+ hypothesis==6.148.7
+ # via
+ # dyff-schema
+ # hypothesis-jsonschema
+ hypothesis-jsonschema==0.23.1
+ # via dyff-schema
+ idna==3.11
+ # via
+ # anyio
+ # email-validator
+ # httpx
+ # jsonschema
+ # requests
+ ipykernel==7.1.0
+ # via jupyterlab
+ ipython==9.8.0
+ # via ipykernel
+ ipython-pygments-lexers==1.1.1
+ # via ipython
+ isodate==0.7.2
+ # via dyff-client
+ isoduration==20.11.0
+ # via jsonschema
+ jedi==0.19.2
+ # via ipython
+ jinja2==3.1.6
+ # via
+ # jupyter-server
+ # jupyterlab
+ # jupyterlab-server
+ # nbconvert
+ # torch
+ json5==0.12.1
+ # via jupyterlab-server
+ jsonpath-ng==1.7.0
+ # via
+ # dyff-client
+ # dyff-schema
+ jsonpointer==3.0.0
+ # via jsonschema
+ jsonschema==4.25.1
+ # via
+ # hypothesis-jsonschema
+ # jupyter-events
+ # jupyterlab-server
+ # nbformat
+ jsonschema-specifications==2025.9.1
+ # via jsonschema
+ jupyter-client==8.7.0
+ # via
+ # ipykernel
+ # jupyter-server
+ # nbclient
+ jupyter-core==5.9.1
+ # via
+ # ipykernel
+ # jupyter-client
+ # jupyter-server
+ # jupyterlab
+ # nbclient
+ # nbconvert
+ # nbformat
+ jupyter-events==0.12.0
+ # via jupyter-server
+ jupyter-lsp==2.3.0
+ # via jupyterlab
+ jupyter-server==2.17.0
+ # via
+ # jupyter-lsp
+ # jupyterlab
+ # jupyterlab-server
+ # notebook
+ # notebook-shim
+ jupyter-server-terminals==0.5.3
+ # via jupyter-server
+ jupyterlab==4.5.0
+ # via notebook
+ jupyterlab-pygments==0.3.0
+ # via nbconvert
+ jupyterlab-server==2.28.0
+ # via
+ # jupyterlab
+ # notebook
+ kiwisolver==1.4.9
+ # via matplotlib
+ lark==1.3.1
+ # via rfc3987-syntax
+ lpips==0.1.4
+ # via -r requirements.in
+ lxml==6.0.2
+ # via dyff-audit
+ markupsafe==3.0.3
+ # via
+ # jinja2
+ # nbconvert
+ matplotlib==3.10.7
+ # via -r requirements.in
+ matplotlib-inline==0.2.1
+ # via
+ # ipykernel
+ # ipython
+ mistune==3.1.4
+ # via nbconvert
+ mpmath==1.3.0
+ # via sympy
+ nbclient==0.10.2
+ # via nbconvert
+ nbconvert==7.16.6
+ # via
+ # dyff-audit
+ # jupyter-server
+ nbformat==5.10.4
+ # via
+ # dyff-audit
+ # jupyter-server
+ # nbclient
+ # nbconvert
+ nest-asyncio==1.6.0
+ # via ipykernel
+ networkx==3.6.1
+ # via torch
+ notebook==7.5.0
+ # via dyff-audit
+ notebook-shim==0.2.4
+ # via
+ # jupyterlab
+ # notebook
+ numpy==1.26.4
+ # via
+ # -r requirements.in
+ # contourpy
+ # dyff-audit
+ # dyff-client
+ # dyff-schema
+ # lpips
+ # matplotlib
+ # opencv-python
+ # pandas
+ # scipy
+ # torchvision
+ # transformers
+ nvidia-cublas-cu12==12.8.4.1
+ # via
+ # nvidia-cudnn-cu12
+ # nvidia-cusolver-cu12
+ # torch
+ nvidia-cuda-cupti-cu12==12.8.90
+ # via torch
+ nvidia-cuda-nvrtc-cu12==12.8.93
+ # via torch
+ nvidia-cuda-runtime-cu12==12.8.90
+ # via torch
+ nvidia-cudnn-cu12==9.10.2.21
+ # via torch
+ nvidia-cufft-cu12==11.3.3.83
+ # via torch
+ nvidia-cufile-cu12==1.13.1.3
+ # via torch
+ nvidia-curand-cu12==10.3.9.90
+ # via torch
+ nvidia-cusolver-cu12==11.7.3.90
+ # via torch
+ nvidia-cusparse-cu12==12.5.8.93
+ # via
+ # nvidia-cusolver-cu12
+ # torch
+ nvidia-cusparselt-cu12==0.7.1
+ # via torch
+ nvidia-nccl-cu12==2.27.5
+ # via torch
+ nvidia-nvjitlink-cu12==12.8.93
+ # via
+ # nvidia-cufft-cu12
+ # nvidia-cusolver-cu12
+ # nvidia-cusparse-cu12
+ # torch
+ nvidia-nvshmem-cu12==3.3.20
+ # via torch
+ nvidia-nvtx-cu12==12.8.90
+ # via torch
+ opencv-python==4.11.0.86
+ # via -r requirements.in
+ packaging==25.0
+ # via
+ # huggingface-hub
+ # ipykernel
+ # jupyter-events
+ # jupyter-server
+ # jupyterlab
+ # jupyterlab-server
+ # matplotlib
+ # nbconvert
+ # transformers
+ pandas==2.3.3
+ # via
+ # -r requirements.in
+ # dyff-audit
+ # dyff-client
+ pandocfilters==1.5.1
+ # via nbconvert
+ parso==0.8.5
+ # via jedi
+ pexpect==4.9.0
+ # via ipython
+ pillow==12.0.0
+ # via
+ # -r requirements.in
+ # matplotlib
+ # torchvision
+ platformdirs==4.5.1
+ # via jupyter-core
+ ply==3.11
+ # via jsonpath-ng
+ prometheus-client==0.23.1
+ # via jupyter-server
+ prompt-toolkit==3.0.52
+ # via ipython
+ psutil==7.1.3
+ # via ipykernel
+ ptyprocess==0.7.0
+ # via
+ # pexpect
+ # terminado
+ pure-eval==0.2.3
+ # via stack-data
+ pyarrow==22.0.0
+ # via
+ # -r requirements.in
+ # dyff-audit
+ # dyff-client
+ # dyff-schema
+ pycparser==2.23
+ # via cffi
+ pydantic==2.12.5
+ # via
+ # -r requirements.in
+ # dyff-audit
+ # dyff-client
+ # dyff-schema
+ # fastapi
+ # pydantic-settings
+ pydantic-core==2.41.5
+ # via pydantic
+ pydantic-settings==2.12.0
+ # via -r requirements.in
+ pygments==2.19.2
+ # via
+ # ipython
+ # ipython-pygments-lexers
+ # nbconvert
+ pyparsing==3.2.5
+ # via matplotlib
+ python-dateutil==2.9.0.post0
+ # via
+ # arrow
+ # jupyter-client
+ # matplotlib
+ # pandas
+ python-dotenv==1.2.1
+ # via
+ # -r requirements.in
+ # pydantic-settings
+ # uvicorn
+ python-json-logger==4.0.0
+ # via jupyter-events
+ python-multipart==0.0.20
+ # via -r requirements.in
+ pytz==2025.2
+ # via pandas
+ pyyaml==6.0.3
+ # via
+ # huggingface-hub
+ # jupyter-events
+ # timm
+ # transformers
+ # uvicorn
+ pyzmq==27.1.0
+ # via
+ # ipykernel
+ # jupyter-client
+ # jupyter-server
+ referencing==0.37.0
+ # via
+ # jsonschema
+ # jsonschema-specifications
+ # jupyter-events
+ regex==2025.11.3
+ # via transformers
+ requests==2.32.5
+ # via
+ # -r requirements.in
+ # azure-core
+ # google-i18n-address
+ # huggingface-hub
+ # jupyterlab-server
+ # transformers
+ rfc3339-validator==0.1.4
+ # via
+ # jsonschema
+ # jupyter-events
+ rfc3986-validator==0.1.1
+ # via
+ # jsonschema
+ # jupyter-events
+ rfc3987-syntax==1.1.0
+ # via jsonschema
+ rpds-py==0.30.0
+ # via
+ # jsonschema
+ # referencing
+ ruamel-yaml==0.18.16
+ # via dyff-audit
+ ruamel-yaml-clib==0.2.15
+ # via ruamel-yaml
+ safetensors==0.7.0
+ # via
+ # timm
+ # transformers
+ scipy==1.16.3
+ # via lpips
+ send2trash==1.8.3
+ # via jupyter-server
+ setuptools==80.9.0
+ # via
+ # jupyterlab
+ # torch
+ six==1.17.0
+ # via
+ # python-dateutil
+ # rfc3339-validator
+ sortedcontainers==2.4.0
+ # via hypothesis
+ soupsieve==2.8
+ # via beautifulsoup4
+ stack-data==0.6.3
+ # via ipython
+ starlette==0.50.0
+ # via fastapi
+ sympy==1.14.0
+ # via torch
+ terminado==0.18.1
+ # via
+ # jupyter-server
+ # jupyter-server-terminals
+ timm==1.0.22
+ # via -r requirements.in
+ tinycss2==1.4.0
+ # via bleach
+ tokenizers==0.22.1
+ # via transformers
+ torch==2.9.1
+ # via
+ # lpips
+ # timm
+ # torchvision
+ torchvision==0.24.1
+ # via
+ # lpips
+ # timm
+ tornado==6.5.2
+ # via
+ # ipykernel
+ # jupyter-client
+ # jupyter-server
+ # jupyterlab
+ # notebook
+ # terminado
+ tqdm==4.67.1
+ # via
+ # dyff-client
+ # huggingface-hub
+ # lpips
+ # transformers
+ traitlets==5.14.3
+ # via
+ # ipykernel
+ # ipython
+ # jupyter-client
+ # jupyter-core
+ # jupyter-events
+ # jupyter-server
+ # jupyterlab
+ # matplotlib-inline
+ # nbclient
+ # nbconvert
+ # nbformat
+ transformers==4.57.3
+ # via -r requirements.in
+ triton
+ # via torch
+ typing
+ # via -r requirements.in
+ typing-extensions
+ # via
+ # anyio
+ # azure-core
+ # beautifulsoup4
+ # fastapi
+ # huggingface-hub
+ # pydantic
+ # pydantic-core
+ # referencing
+ # starlette
+ # torch
+ # typing-inspection
+ typing-inspection
+ # via
+ # pydantic
+ # pydantic-settings
+ tzdata
+ # via
+ # arrow
+ # pandas
+ uri-template
+ # via jsonschema
+ urllib3
+ # via requests
+ uvicorn
+ # via -r requirements.in
+ uvloop
+ # via uvicorn
+ watchfiles
+ # via uvicorn
+ wcwidth==0.2.14
+ # via prompt-toolkit
+ webcolors==25.10.0
+ # via jsonschema
+ webencodings==0.5.1
+ # via
+ # bleach
+ # tinycss2
+ websocket-client==1.9.0
+ # via jupyter-server
+ websockets==15.0.1
+ # via
+ # dyff-client
+ # uvicorn
requirements.in ADDED
@@ -0,0 +1,34 @@
+ dyff-audit
+ dyff-client
+ dyff-schema>=0.39.1
+
+ click
+
+ # Web framework
+ fastapi
+ uvicorn[standard]
+
+ # Configuration management
+ pydantic>=2.0.0,<3.0.0
+ pydantic-settings
+ python-dotenv
+
+ # File upload handling
+ python-multipart
+
+ # ML/AI dependencies (newly added)
+ transformers>=4.35.0
+ pillow>=10.0.0
+
+ # Dataset generation and testing
+ pyarrow>=14.0.0
+ numpy
+ pandas
+ requests>=2.25.0
+ opencv-python
+ matplotlib
+ typing
+ einops
+ timm
+ beartype
+ lpips
requirements.torch.cpu.in ADDED
@@ -0,0 +1,3 @@
+ --index-url https://download.pytorch.org/whl/cpu
+ torch==2.9.1+cpu
+ torchvision==0.24.1+cpu
requirements.torch.cpu.txt ADDED
@@ -0,0 +1,9 @@
+ #--index-url https://download.pytorch.org/whl/cu126
+ #torch==2.9.1
+ #torchvision==0.24.1
+ #torchaudio==2.9.0
+
+ --index-url https://download.pytorch.org/whl/nightly/cu121
+ torch
+ torchvision
+ torchaudio
requirements.torch.gpu.in ADDED
@@ -0,0 +1,2 @@
+ torch
+ torchvision
requirements.torch.gpu.txt ADDED
@@ -0,0 +1,71 @@
+ # This file was autogenerated by uv via the following command:
+ # uv pip compile --python-version 3.12 -o requirements.torch.gpu.txt requirements.torch.gpu.in
+ filelock==3.20.0
+ # via torch
+ fsspec==2025.12.0
+ # via torch
+ jinja2==3.1.6
+ # via torch
+ markupsafe==3.0.3
+ # via jinja2
+ mpmath==1.3.0
+ # via sympy
+ networkx==3.6.1
+ # via torch
+ numpy==2.3.5
+ # via torchvision
+ nvidia-cublas-cu12==12.8.4.1
+ # via
+ # nvidia-cudnn-cu12
+ # nvidia-cusolver-cu12
+ # torch
+ nvidia-cuda-cupti-cu12==12.8.90
+ # via torch
+ nvidia-cuda-nvrtc-cu12==12.8.93
+ # via torch
+ nvidia-cuda-runtime-cu12==12.8.90
+ # via torch
+ nvidia-cudnn-cu12==9.10.2.21
+ # via torch
+ nvidia-cufft-cu12==11.3.3.83
+ # via torch
+ nvidia-cufile-cu12==1.13.1.3
+ # via torch
+ nvidia-curand-cu12==10.3.9.90
+ # via torch
+ nvidia-cusolver-cu12==11.7.3.90
+ # via torch
+ nvidia-cusparse-cu12==12.5.8.93
+ # via
+ # nvidia-cusolver-cu12
+ # torch
+ nvidia-cusparselt-cu12==0.7.1
+ # via torch
+ nvidia-nccl-cu12==2.27.5
+ # via torch
+ nvidia-nvjitlink-cu12==12.8.93
+ # via
+ # nvidia-cufft-cu12
+ # nvidia-cusolver-cu12
+ # nvidia-cusparse-cu12
+ # torch
+ nvidia-nvshmem-cu12==3.3.20
+ # via torch
+ nvidia-nvtx-cu12==12.8.90
+ # via torch
+ pillow==12.0.0
+ # via torchvision
+ setuptools==80.9.0
+ # via torch
+ sympy==1.14.0
+ # via torch
+ torch==2.9.1
+ # via
+ # -r requirements.torch.gpu.in
+ # torchvision
+ torchvision==0.24.1
+ # via -r requirements.torch.gpu.in
+ triton==3.5.1
+ # via torch
+ typing-extensions==4.15.0
+ # via torch
requirements.txt ADDED
@@ -0,0 +1,530 @@
+ # This file was autogenerated by uv via the following command:
+ # uv pip compile --python-version 3.12 -o requirements.txt requirements.in
+ absl-py==2.3.1
+ # via dyff-client
+ annotated-types==0.7.0
+ # via pydantic
+ anyio==3.7.1
+ # via
+ # fastapi
+ # httpx
+ # jupyter-server
+ # starlette
+ # watchfiles
+ argon2-cffi==25.1.0
+ # via jupyter-server
+ argon2-cffi-bindings==25.1.0
+ # via argon2-cffi
+ arrow==1.4.0
+ # via isoduration
+ asttokens==3.0.1
+ # via stack-data
+ async-lru==2.0.5
+ # via jupyterlab
+ attrs==25.4.0
+ # via
+ # jsonschema
+ # referencing
+ azure-core==1.36.0
+ # via dyff-client
+ babel==2.17.0
+ # via jupyterlab-server
+ beautifulsoup4==4.14.2
+ # via
+ # dyff-audit
+ # nbconvert
+ bleach==6.3.0
+ # via nbconvert
+ canonicaljson==2.0.0
+ # via dyff-schema
+ certifi==2025.11.12
+ # via
+ # httpcore
+ # httpx
+ # requests
+ cffi==2.0.0
+ # via argon2-cffi-bindings
+ charset-normalizer==3.4.4
+ # via requests
+ click==8.3.1
+ # via
+ # -r requirements.in
+ # uvicorn
+ comm==0.2.3
+ # via ipykernel
+ debugpy==1.8.17
+ # via ipykernel
+ decorator==5.2.1
+ # via ipython
+ defusedxml==0.7.1
+ # via nbconvert
+ dnspython==2.8.0
+ # via email-validator
+ dyff-audit==0.16.1
+ # via -r requirements.in
+ dyff-client==0.23.5
+ # via
+ # -r requirements.in
+ # dyff-audit
+ dyff-schema==0.39.1
+ # via
+ # -r requirements.in
+ # dyff-audit
+ # dyff-client
+ email-validator==2.3.0
+ # via dyff-schema
+ executing==2.2.1
+ # via stack-data
+ fastapi==0.104.1
+ # via -r requirements.in
+ fastjsonschema==2.21.2
+ # via nbformat
+ filelock==3.20.0
+ # via
+ # huggingface-hub
+ # torch
+ # transformers
+ fqdn==1.5.1
+ # via jsonschema
+ fsspec==2025.10.0
+ # via
+ # huggingface-hub
+ # torch
+ google-i18n-address==3.1.1
+ # via dyff-schema
+ h11==0.16.0
+ # via
+ # httpcore
+ # uvicorn
+ hf-xet==1.2.0
+ # via huggingface-hub
+ httpcore==1.0.9
+ # via httpx
+ httptools==0.7.1
+ # via uvicorn
+ httpx==0.28.1
+ # via
+ # dyff-client
+ # jupyterlab
+ huggingface-hub==0.36.0
+ # via
+ # tokenizers
+ # transformers
+ hypothesis==6.148.1
+ # via
+ # dyff-schema
+ # hypothesis-jsonschema
+ hypothesis-jsonschema==0.23.1
+ # via dyff-schema
+ idna==3.11
+ # via
+ # anyio
+ # email-validator
+ # httpx
+ # jsonschema
+ # requests
+ ipykernel==7.1.0
+ # via jupyterlab
+ ipython==9.7.0
+ # via ipykernel
+ ipython-pygments-lexers==1.1.1
+ # via ipython
+ isodate==0.7.2
+ # via dyff-client
+ isoduration==20.11.0
+ # via jsonschema
+ jedi==0.19.2
+ # via ipython
+ jinja2==3.1.6
+ # via
+ # jupyter-server
+ # jupyterlab
+ # jupyterlab-server
+ # nbconvert
+ # torch
+ json5==0.12.1
+ # via jupyterlab-server
+ jsonpath-ng==1.7.0
+ # via
+ # dyff-client
+ # dyff-schema
+ jsonpointer==3.0.0
+ # via jsonschema
+ jsonschema==4.25.1
+ # via
+ # hypothesis-jsonschema
+ # jupyter-events
+ # jupyterlab-server
+ # nbformat
+ jsonschema-specifications==2025.9.1
+ # via jsonschema
+ jupyter-client==8.6.3
+ # via
+ # ipykernel
+ # jupyter-server
+ # nbclient
+ jupyter-core==5.9.1
+ # via
+ # ipykernel
+ # jupyter-client
+ # jupyter-server
+ # jupyterlab
+ # nbclient
+ # nbconvert
+ # nbformat
+ jupyter-events==0.12.0
+ # via jupyter-server
+ jupyter-lsp==2.3.0
+ # via jupyterlab
+ jupyter-server==2.17.0
+ # via
+ # jupyter-lsp
+ # jupyterlab
+ # jupyterlab-server
+ # notebook
+ # notebook-shim
+ jupyter-server-terminals==0.5.3
+ # via jupyter-server
+ jupyterlab==4.4.10
+ # via notebook
+ jupyterlab-pygments==0.3.0
+ # via nbconvert
+ jupyterlab-server==2.28.0
+ # via
+ # jupyterlab
+ # notebook
+ lark==1.3.1
+ # via rfc3987-syntax
+ lxml==6.0.2
+ # via dyff-audit
+ markupsafe==3.0.3
+ # via
+ # jinja2
+ # nbconvert
+ matplotlib-inline==0.2.1
+ # via
+ # ipykernel
+ # ipython
+ mistune==3.1.4
+ # via nbconvert
+ mpmath==1.3.0
+ # via sympy
+ nbclient==0.10.2
+ # via nbconvert
+ nbconvert==7.16.6
+ # via
+ # dyff-audit
+ # jupyter-server
+ nbformat==5.10.4
+ # via
+ # dyff-audit
+ # jupyter-server
+ # nbclient
+ # nbconvert
+ nest-asyncio==1.6.0
+ # via ipykernel
+ networkx==3.5
+ # via torch
+ notebook==7.4.7
+ # via dyff-audit
+ notebook-shim==0.2.4
+ # via
+ # jupyterlab
+ # notebook
+ numpy==1.26.4
+ # via
+ # -r requirements.in
+ # dyff-audit
+ # dyff-client
+ # dyff-schema
+ # pandas
+ # transformers
+ nvidia-cublas-cu12==12.8.4.1
+ # via
+ # nvidia-cudnn-cu12
+ # nvidia-cusolver-cu12
+ # torch
+ nvidia-cuda-cupti-cu12==12.8.90
+ # via torch
+ nvidia-cuda-nvrtc-cu12==12.8.93
+ # via torch
+ nvidia-cuda-runtime-cu12==12.8.90
+ # via torch
+ nvidia-cudnn-cu12==9.10.2.21
+ # via torch
+ nvidia-cufft-cu12==11.3.3.83
+ # via torch
+ nvidia-cufile-cu12==1.13.1.3
+ # via torch
+ nvidia-curand-cu12==10.3.9.90
+ # via torch
+ nvidia-cusolver-cu12==11.7.3.90
+ # via torch
+ nvidia-cusparse-cu12==12.5.8.93
+ # via
+ # nvidia-cusolver-cu12
+ # torch
+ nvidia-cusparselt-cu12==0.7.1
+ # via torch
+ nvidia-nccl-cu12==2.27.5
+ # via torch
+ nvidia-nvjitlink-cu12==12.8.93
+ # via
+ # nvidia-cufft-cu12
+ # nvidia-cusolver-cu12
+ # nvidia-cusparse-cu12
+ # torch
+ nvidia-nvshmem-cu12==3.3.20
+ # via torch
+ nvidia-nvtx-cu12==12.8.90
+ # via torch
+ packaging==25.0
+ # via
+ # huggingface-hub
+ # ipykernel
+ # jupyter-events
+ # jupyter-server
+ # jupyterlab
+ # jupyterlab-server
+ # nbconvert
+ # transformers
+ pandas==2.3.3
+ # via
+ # -r requirements.in
+ # dyff-audit
+ # dyff-client
+ pandocfilters==1.5.1
+ # via nbconvert
+ parso==0.8.5
+ # via jedi
+ pexpect==4.9.0
+ # via ipython
+ pillow==12.0.0
+ # via -r requirements.in
+ platformdirs==4.5.0
+ # via jupyter-core
+ ply==3.11
+ # via jsonpath-ng
+ prometheus-client==0.23.1
+ # via jupyter-server
+ prompt-toolkit==3.0.52
+ # via ipython
+ psutil==7.1.3
+ # via ipykernel
+ ptyprocess==0.7.0
+ # via
+ # pexpect
+ # terminado
+ pure-eval==0.2.3
+ # via stack-data
+ pyarrow==22.0.0
+ # via
+ # -r requirements.in
+ # dyff-audit
+ # dyff-client
+ # dyff-schema
+ pycparser==2.23
+ # via cffi
+ pydantic==2.5.0
+ # via
+ # -r requirements.in
+ # dyff-audit
+ # dyff-client
+ # dyff-schema
+ # fastapi
+ # pydantic-settings
+ pydantic-core==2.14.1
+ # via pydantic
+ pydantic-settings==2.0.3
+ # via -r requirements.in
+ pygments==2.19.2
+ # via
+ # ipython
+ # ipython-pygments-lexers
+ # nbconvert
+ python-dateutil==2.9.0.post0
+ # via
+ # arrow
+ # jupyter-client
+ # pandas
+ python-dotenv==0.21.0
+ # via
+ # -r requirements.in
+ # pydantic-settings
+ # uvicorn
+ python-json-logger==4.0.0
+ # via jupyter-events
+ python-multipart==0.0.6
+ # via -r requirements.in
+ pytz==2025.2
+ # via pandas
+ pyyaml==6.0.3
+ # via
+ # huggingface-hub
+ # jupyter-events
+ # transformers
+ # uvicorn
+ pyzmq==27.1.0
+ # via
+ # ipykernel
+ # jupyter-client
+ # jupyter-server
+ referencing==0.37.0
+ # via
+ # jsonschema
+ # jsonschema-specifications
+ # jupyter-events
+ regex==2025.11.3
+ # via transformers
+ requests==2.32.5
+ # via
+ # -r requirements.in
+ # azure-core
+ # google-i18n-address
+ # huggingface-hub
+ # jupyterlab-server
+ # transformers
+ rfc3339-validator==0.1.4
+ # via
+ # jsonschema
+ # jupyter-events
+ rfc3986-validator==0.1.1
+ # via
+ # jsonschema
+ # jupyter-events
+ rfc3987-syntax==1.1.0
+ # via jsonschema
+ rpds-py==0.29.0
+ # via
+ # jsonschema
+ # referencing
+ ruamel-yaml==0.18.16
+ # via dyff-audit
+ ruamel-yaml-clib==0.2.15
+ # via ruamel-yaml
+ safetensors==0.6.2
+ # via transformers
+ send2trash==1.8.3
+ # via jupyter-server
+ setuptools==80.9.0
+ # via
+ # jupyterlab
+ # torch
+ six==1.17.0
+ # via
+ # python-dateutil
+ # rfc3339-validator
+ sniffio==1.3.1
+ # via anyio
+ sortedcontainers==2.4.0
+ # via hypothesis
+ soupsieve==2.8
+ # via beautifulsoup4
+ stack-data==0.6.3
+ # via ipython
+ starlette==0.27.0
+ # via fastapi
+ sympy==1.14.0
+ # via torch
+ terminado==0.18.1
+ # via
+ # jupyter-server
+ # jupyter-server-terminals
+ tinycss2==1.4.0
+ # via bleach
+ tokenizers==0.22.1
+ # via transformers
+ torch==2.9.1
+ # via -r requirements.in
+ tornado==6.5.2
+ # via
+ # ipykernel
+ # jupyter-client
+ # jupyter-server
+ # jupyterlab
+ # notebook
+ # terminado
+ tqdm==4.67.1
+ # via
+ # dyff-client
+ # huggingface-hub
+ # transformers
+ traitlets==5.14.3
+ # via
+ # ipykernel
+ # ipython
+ # jupyter-client
+ # jupyter-core
+ # jupyter-events
+ # jupyter-server
+ # jupyterlab
+ # matplotlib-inline
+ # nbclient
+ # nbconvert
+ # nbformat
+ transformers==4.57.1
+ # via -r requirements.in
+ triton==3.5.1
+ # via torch
+ typing-extensions==4.15.0
+ # via
+ # azure-core
+ # beautifulsoup4
+ # fastapi
+ # huggingface-hub
+ # pydantic
+ # pydantic-core
+ # referencing
+ # torch
+ tzdata==2025.2
+ # via
+ # arrow
+ # pandas
+ uri-template==1.3.0
+ # via jsonschema
+ urllib3==2.5.0
+ # via requests
+ uvicorn==0.24.0
+ # via -r requirements.in
+ uvloop==0.22.1
+ # via uvicorn
+ watchfiles==1.1.1
+ # via uvicorn
+ wcwidth==0.2.14
+ # via prompt-toolkit
+ webcolors==25.10.0
+ # via jsonschema
+ webencodings==0.5.1
+ # via
+ # bleach
+ # tinycss2
+ websocket-client==1.9.0
+ # via jupyter-server
+ websockets==15.0.1
+ # via
+ # dyff-client
+ # uvicorn
+
+ numpy
+ pandas
+ scipy==1.12.0
+
+
+ opencv-contrib-python==4.11.0.86
+ av==14.0.1
+ decord
+ kornia==0.7.4
+ timm
+ einops
+ beartype
+ lpips==0.1.4
+ imageio
+
+ audiomentations==0.36.0
+ librosa==0.10.2.post1
+ soundfile==0.12.1
+
+ transformers
+ datasets
+ tqdm
+ opencv-python
response.json ADDED
@@ -0,0 +1 @@
+ {"logprobs":[-0.859380304813385,-1.2701971530914307,-2.1918208599090576,-1.69235098361969],"localizationMask":{"mediaType":"image/png","data":"iVBORw0KGgoAAAANSUhEUgAAA8AAAAKDAQAAAAD9Fl5AAAAAu0lEQVR4nO3NsREAMAgDMWD/nZMVKEwn1T5/FQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAMCl3g5f+HC24TRhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWFhYWEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAj70gwKsTlmdBwAAAABJRU5ErkJggg=="}}
scripts/generate_test_datasets.py ADDED
@@ -0,0 +1,411 @@
+ #!/usr/bin/env python3
+ """
+ PyArrow Dataset Generator for ML Inference Service
+
+ Generates test datasets for academic challenges and model validation.
+ Creates 100 PyArrow datasets with various image types and test scenarios.
+ """
+
+ import base64
+ import json
+ import random
+ from pathlib import Path
+ from typing import Dict, List, Any, Tuple
+ import io
+
+ import numpy as np
+ import pyarrow as pa
+ import pyarrow.parquet as pq
+ from PIL import Image, ImageDraw, ImageFont
+
+
+ class TestDatasetGenerator:
+     def __init__(self, output_dir: str = "test_datasets"):
+         self.output_dir = Path(output_dir)
+         self.output_dir.mkdir(exist_ok=True)
+
+         # ImageNet class labels (sample for testing)
+         self.imagenet_labels = [
+             "tench", "goldfish", "great_white_shark", "tiger_shark", "hammerhead",
+             "electric_ray", "stingray", "cock", "hen", "ostrich", "brambling",
+             "goldfinch", "house_finch", "junco", "indigo_bunting", "robin",
+             "bulbul", "jay", "magpie", "chickadee", "water_ouzel", "kite",
+             "bald_eagle", "vulture", "great_grey_owl", "European_fire_salamander",
+             "common_newt", "eft", "spotted_salamander", "axolotl", "bullfrog",
+             "tree_frog", "tailed_frog", "loggerhead", "leatherback_turtle",
+             "mud_turtle", "terrapin", "box_turtle", "banded_gecko", "common_iguana",
+             "American_chameleon", "whiptail", "agama", "frilled_lizard", "alligator_lizard",
+             "Gila_monster", "green_lizard", "African_chameleon", "Komodo_dragon",
+             "African_crocodile", "American_alligator", "triceratops", "thunder_snake"
+         ]
+
+     def create_synthetic_image(self, width: int = 224, height: int = 224,
+                                image_type: str = "random") -> Image.Image:
+         """Create synthetic images for testing."""
+         if image_type == "random":
+             # Random noise image
+             array = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
+             return Image.fromarray(array)
+
+         elif image_type == "geometric":
+             # Geometric patterns
+             img = Image.new('RGB', (width, height), color='white')
+             draw = ImageDraw.Draw(img)
+
+             # Draw random shapes
+             for _ in range(random.randint(3, 8)):
+                 color = tuple(random.randint(0, 255) for _ in range(3))
+                 shape_type = random.choice(['rectangle', 'ellipse'])
+                 x1, y1 = random.randint(0, width//2), random.randint(0, height//2)
+                 x2, y2 = x1 + random.randint(20, width//2), y1 + random.randint(20, height//2)
+
+                 if shape_type == 'rectangle':
+                     draw.rectangle([x1, y1, x2, y2], fill=color)
+                 else:
+                     draw.ellipse([x1, y1, x2, y2], fill=color)
+
+             return img
+
+         elif image_type == "gradient":
+             array = np.zeros((height, width, 3), dtype=np.uint8)
+             for i in range(height):
+                 for j in range(width):
+                     array[i, j] = [i * 255 // height, j * 255 // width, (i + j) * 255 // (height + width)]
+             return Image.fromarray(array)
+
+         elif image_type == "text":
+             img = Image.new('RGB', (width, height), color='white')
+             draw = ImageDraw.Draw(img)
+
+             try:
+                 font = ImageFont.load_default()
+             except Exception:
+                 font = None
+
+             text = f"Test Image {random.randint(1, 1000)}"
+             draw.text((width//4, height//2), text, fill='black', font=font)
+             return img
+
+         else:
+             color = tuple(random.randint(0, 255) for _ in range(3))
+             return Image.new('RGB', (width, height), color=color)
+
+     def image_to_base64(self, image: Image.Image, format: str = "JPEG") -> str:
+         """Convert PIL image to base64 string."""
+         buffer = io.BytesIO()
+         image.save(buffer, format=format)
+         image_bytes = buffer.getvalue()
+         return base64.b64encode(image_bytes).decode('utf-8')
+
+     def create_api_request(self, image_b64: str, media_type: str = "image/jpeg") -> Dict[str, Any]:
+         """Create API request structure matching your service."""
+         return {
+             "image": {
+                 "mediaType": media_type,
+                 "data": image_b64
+             }
+         }
+
+     def create_expected_response(self, model_name: str = "microsoft/resnet-18",
+                                  media_type: str = "image/jpeg") -> Dict[str, Any]:
+         """Create expected response structure."""
+         prediction = random.choice(self.imagenet_labels)
+         return {
+             "prediction": prediction,
+             "confidence": round(random.uniform(0.3, 0.99), 4),
+             "predicted_label": random.randint(0, len(self.imagenet_labels) - 1),
+             "model": model_name,
+             "mediaType": media_type
+         }
+
+     def generate_standard_datasets(self, count: int = 25) -> List[Dict[str, Any]]:
+         """Generate standard test cases with normal images."""
+         datasets = []
+
+         for i in range(count):
+             image_types = ["random", "geometric", "gradient", "text", "solid"]
+             sizes = [(224, 224), (256, 256), (299, 299), (384, 384)]
+             formats = [("JPEG", "image/jpeg"), ("PNG", "image/png")]
+
+             records = []
+             for j in range(random.randint(5, 20)):  # 5-20 images per dataset
+                 img_type = random.choice(image_types)
+                 size = random.choice(sizes)
+                 format_info = random.choice(formats)
+
+                 image = self.create_synthetic_image(size[0], size[1], img_type)
+                 image_b64 = self.image_to_base64(image, format_info[0])
+
+                 api_request = self.create_api_request(image_b64, format_info[1])
+                 expected_response = self.create_expected_response()
+
+                 record = {
+                     "dataset_id": f"standard_{i:03d}",
+                     "image_id": f"img_{j:03d}",
+                     "image_type": img_type,
+                     "image_size": f"{size[0]}x{size[1]}",
+                     "format": format_info[0],
+                     "media_type": format_info[1],
+                     "api_request": json.dumps(api_request),
+                     "expected_response": json.dumps(expected_response),
+                     "test_category": "standard",
+                     "difficulty": "normal"
+                 }
+                 records.append(record)
+
+             datasets.append({
+                 "name": f"standard_test_{i:03d}",
+                 "category": "standard",
+                 "description": f"Standard test dataset {i+1} with {len(records)} images",
+                 "records": records
+             })
+
+         return datasets
+
+     def generate_edge_case_datasets(self, count: int = 25) -> List[Dict[str, Any]]:
+         """Generate datasets for edge case scenarios."""
+         datasets = []
+
+         for i in range(count):
+             records = []
+             edge_cases = [
+                 {"type": "tiny", "size": (32, 32), "difficulty": "high"},
+                 {"type": "huge", "size": (2048, 2048), "difficulty": "high"},
+                 {"type": "extreme_aspect", "size": (1000, 50), "difficulty": "medium"},
+                 {"type": "single_pixel", "size": (1, 1), "difficulty": "extreme"},
+                 {"type": "corrupted_base64", "size": (224, 224), "difficulty": "extreme"}
+             ]
+
+             for j, edge_case in enumerate(edge_cases):
+                 if edge_case["type"] == "corrupted_base64":
+                     image = self.create_synthetic_image(224, 224, "random")
+                     image_b64 = self.image_to_base64(image, "JPEG")
+                     corrupted_b64 = image_b64[:-20] + "CORRUPTED_DATA"
+                     api_request = self.create_api_request(corrupted_b64)
+                     expected_response = {
+                         "error": "Invalid image data",
+                         "status": "failed"
+                     }
+                 else:
+                     image = self.create_synthetic_image(
+                         edge_case["size"][0], edge_case["size"][1], "random"
+                     )
+                     image_b64 = self.image_to_base64(image, "PNG")
+                     api_request = self.create_api_request(image_b64, "image/png")
+                     expected_response = self.create_expected_response()
+
+                 record = {
+                     "dataset_id": f"edge_{i:03d}",
+                     "image_id": f"edge_{j:03d}",
+                     "image_type": edge_case["type"],
+                     "image_size": f"{edge_case['size'][0]}x{edge_case['size'][1]}",
+                     "format": "PNG",
+                     "media_type": "image/png",
+                     "api_request": json.dumps(api_request),
+                     "expected_response": json.dumps(expected_response),
+                     "test_category": "edge_case",
+                     "difficulty": edge_case["difficulty"]
+                 }
+                 records.append(record)
+
+             datasets.append({
+                 "name": f"edge_case_{i:03d}",
+                 "category": "edge_case",
+                 "description": f"Edge case dataset {i+1} with challenging scenarios",
+                 "records": records
+             })
+
+         return datasets
+
+     def generate_performance_datasets(self, count: int = 25) -> List[Dict[str, Any]]:
+         """Generate performance benchmark datasets."""
+         datasets = []
+
+         for i in range(count):
+             batch_sizes = [1, 5, 10, 25, 50, 100]
+             batch_size = random.choice(batch_sizes)
+
+             records = []
+             for j in range(batch_size):
+                 image = self.create_synthetic_image(224, 224, "random")
+                 image_b64 = self.image_to_base64(image, "JPEG")
+                 api_request = self.create_api_request(image_b64)
+                 expected_response = self.create_expected_response()
+
+                 record = {
+                     "dataset_id": f"perf_{i:03d}",
+                     "image_id": f"batch_{j:03d}",
+                     "image_type": "performance_test",
+                     "image_size": "224x224",
+                     "format": "JPEG",
+                     "media_type": "image/jpeg",
+                     "api_request": json.dumps(api_request),
+                     "expected_response": json.dumps(expected_response),
+                     "test_category": "performance",
+                     "difficulty": "normal",
+                     "batch_size": batch_size,
+                     "expected_max_latency_ms": batch_size * 100
+                 }
+                 records.append(record)
+
+             datasets.append({
+                 "name": f"performance_test_{i:03d}",
+                 "category": "performance",
+                 "description": f"Performance dataset {i+1} with batch size {batch_size}",
+                 "records": records
+             })
+
+         return datasets
+
+     def generate_model_comparison_datasets(self, count: int = 25) -> List[Dict[str, Any]]:
+         """Generate datasets for comparing different models."""
+         datasets = []
+
+         model_types = [
+             "microsoft/resnet-18", "microsoft/resnet-50", "google/vit-base-patch16-224",
+             "facebook/convnext-tiny-224", "microsoft/swin-tiny-patch4-window7-224"
+         ]
+
+         for i in range(count):
+             # Same images tested across different model types
+             base_images = []
+             for _ in range(10):  # 10 base images per comparison dataset
+                 image = self.create_synthetic_image(224, 224, "geometric")
+                 base_images.append(self.image_to_base64(image, "JPEG"))
+
+             records = []
+             for j, model in enumerate(model_types):
+                 for k, image_b64 in enumerate(base_images):
+                     api_request = self.create_api_request(image_b64)
+                     expected_response = self.create_expected_response(model)
+
+                     record = {
+                         "dataset_id": f"comparison_{i:03d}",
+                         "image_id": f"img_{k:03d}_model_{j}",
+                         "image_type": "comparison_base",
+                         "image_size": "224x224",
+                         "format": "JPEG",
+                         "media_type": "image/jpeg",
+                         "api_request": json.dumps(api_request),
+                         "expected_response": json.dumps(expected_response),
+                         "test_category": "model_comparison",
+                         "difficulty": "normal",
+                         "model_type": model,
+                         "comparison_group": k
+                     }
+                     records.append(record)
+
+             datasets.append({
+                 "name": f"model_comparison_{i:03d}",
+                 "category": "model_comparison",
+                 "description": f"Model comparison dataset {i+1} testing {len(model_types)} models",
+                 "records": records
+             })
+
+         return datasets
+
+     def save_dataset_to_parquet(self, dataset: Dict[str, Any]):
+         """Save a dataset to PyArrow Parquet format."""
+         records = dataset["records"]
+
+         # Convert to PyArrow table
+         table = pa.table({
+             "dataset_id": [r["dataset_id"] for r in records],
+             "image_id": [r["image_id"] for r in records],
+             "image_type": [r["image_type"] for r in records],
+             "image_size": [r["image_size"] for r in records],
+             "format": [r["format"] for r in records],
+             "media_type": [r["media_type"] for r in records],
+             "api_request": [r["api_request"] for r in records],
+             "expected_response": [r["expected_response"] for r in records],
+             "test_category": [r["test_category"] for r in records],
+             "difficulty": [r["difficulty"] for r in records],
+             # Optional fields with defaults
+             "batch_size": [r.get("batch_size", 1) for r in records],
+             "expected_max_latency_ms": [r.get("expected_max_latency_ms", 1000) for r in records],
+             "model_type": [r.get("model_type", "microsoft/resnet-18") for r in records],
+             "comparison_group": [r.get("comparison_group", 0) for r in records]
+         })
+
+         output_path = self.output_dir / f"{dataset['name']}.parquet"
+         pq.write_table(table, output_path)
+
+         # Save metadata as JSON
+         metadata = {
+             "name": dataset["name"],
+             "category": dataset["category"],
+             "description": dataset["description"],
+             "record_count": len(records),
+             "file_size_mb": round(output_path.stat().st_size / (1024 * 1024), 2),
+             "schema": [field.name for field in table.schema]
+         }
+
+         metadata_path = self.output_dir / f"{dataset['name']}_metadata.json"
+         with open(metadata_path, 'w') as f:
+             json.dump(metadata, f, indent=2)
+
+     def generate_all_datasets(self):
+         """Generate all 100 datasets."""
+         print(" Starting dataset generation...")
+
+         print("📊 Generating standard test datasets (25)...")
+         standard_datasets = self.generate_standard_datasets(25)
+         for dataset in standard_datasets:
+             self.save_dataset_to_parquet(dataset)
+
+         print("⚡ Generating edge case datasets (25)...")
+         edge_datasets = self.generate_edge_case_datasets(25)
+         for dataset in edge_datasets:
+             self.save_dataset_to_parquet(dataset)
+
+         print("🏁 Generating performance datasets (25)...")
+         performance_datasets = self.generate_performance_datasets(25)
+         for dataset in performance_datasets:
+             self.save_dataset_to_parquet(dataset)
+
+         print("🔄 Generating model comparison datasets (25)...")
+         comparison_datasets = self.generate_model_comparison_datasets(25)
+         for dataset in comparison_datasets:
+             self.save_dataset_to_parquet(dataset)
+
+         print(f"✅ Generated 100 datasets in {self.output_dir}/")
+
+         self.generate_summary()
+
+     def generate_summary(self):
+         """Generate a summary of all datasets."""
+         summary = {
+             "total_datasets": 100,
+             "categories": {
+                 "standard": 25,
+                 "edge_case": 25,
+                 "performance": 25,
+                 "model_comparison": 25
+             },
+             "dataset_info": [],
+             "usage_instructions": {
+                 "loading": "Use pyarrow.parquet.read_table('dataset.parquet')",
+                 "testing": "Run python scripts/test_datasets.py",
+                 "api_endpoint": "POST /predict/resnet",
+                 "request_format": "See api_request column in datasets"
+             }
+         }
+
+         # Add individual dataset info
+         for parquet_file in self.output_dir.glob("*.parquet"):
+             metadata_file = self.output_dir / f"{parquet_file.stem}_metadata.json"
+             if metadata_file.exists():
+                 with open(metadata_file, 'r') as f:
+                     metadata = json.load(f)
+                 summary["dataset_info"].append(metadata)
+
+         summary_path = self.output_dir / "datasets_summary.json"
+         with open(summary_path, 'w') as f:
+             json.dump(summary, f, indent=2)
+
+         print(f"📋 Summary saved to {summary_path}")
+
+
+ if __name__ == "__main__":
+     generator = TestDatasetGenerator()
+     generator.generate_all_datasets()
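+ # Loading sketch for the generated files (the dataset name below is one the
+ # script itself emits; any of the *.parquet outputs works the same way):
+ #   import pyarrow.parquet as pq
+ #   df = pq.read_table("test_datasets/standard_test_000.parquet").to_pandas()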
scripts/model_download.bash ADDED
@@ -0,0 +1,10 @@
+ mkdir -p "models/microsoft/resnet-18"
+
+ python - <<'PY'
+ from huggingface_hub import snapshot_download
+ snapshot_download(
+     repo_id="microsoft/resnet-18",
+     local_dir="models/microsoft/resnet-18",
+     local_dir_use_symlinks=False  # copies files; safer for containers
+ )
+ PY
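+ # snapshot_download fetches every file of the repo revision into local_dir;
+ # with symlinks disabled these are real copies, so the models/ tree can be
+ # baked into the Docker image without dangling cache links.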
scripts/test_datasets.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Dataset Tester for ML Inference Service
4
+
5
+ Tests the generated PyArrow datasets against the running ML inference service.
6
+ Validates API requests/responses and measures performance metrics.
7
+ """
8
+
9
+ import json
10
+ import time
11
+ import asyncio
12
+ import statistics
13
+ from pathlib import Path
14
+ from typing import Dict, List, Any, Optional
15
+ import argparse
16
+
17
+ import pyarrow.parquet as pq
18
+ import requests
19
+ import pandas as pd
20
+
21
+
22
+ class DatasetTester:
23
+ def __init__(self, base_url: str = "http://127.0.0.1:8000", datasets_dir: str = "test_datasets"):
24
+ self.base_url = base_url.rstrip('/')
25
+ self.datasets_dir = Path(datasets_dir)
26
+ self.endpoint = f"{self.base_url}/predict"
27
+ self.results = []
28
+
29
+ def load_dataset(self, dataset_path: Path) -> pd.DataFrame:
30
+ """Load a PyArrow dataset."""
31
+ table = pq.read_table(dataset_path)
32
+ return table.to_pandas()
33
+
34
+ def test_api_connection(self) -> bool:
35
+ """Test if the API is running and accessible."""
36
+ try:
37
+ response = requests.get(f"{self.base_url}/docs", timeout=5)
38
+ return response.status_code == 200
39
+ except requests.RequestException:
40
+ return False
41
+
42
+ def send_prediction_request(self, api_request_json: str) -> Dict[str, Any]:
43
+ """Send a single prediction request to the API."""
44
+ try:
45
+ request_data = json.loads(api_request_json)
46
+ start_time = time.time()
47
+
48
+ response = requests.post(
49
+ self.endpoint,
50
+ json=request_data,
51
+ headers={"Content-Type": "application/json"},
52
+ timeout=30
53
+ )
54
+
55
+ end_time = time.time()
56
+ latency_ms = (end_time - start_time) * 1000
57
+
58
+ return {
59
+ "success": response.status_code == 200,
60
+ "status_code": response.status_code,
61
+ "response": response.json() if response.status_code == 200 else response.text,
62
+ "latency_ms": round(latency_ms, 2),
63
+ "error": None
64
+ }
65
+
66
+ except requests.RequestException as e:
67
+ return {
68
+ "success": False,
69
+ "status_code": None,
70
+ "response": None,
71
+ "latency_ms": None,
72
+ "error": str(e)
73
+ }
74
+ except json.JSONDecodeError as e:
75
+ return {
76
+ "success": False,
77
+ "status_code": None,
78
+ "response": None,
79
+ "latency_ms": None,
80
+ "error": f"JSON decode error: {str(e)}"
81
+ }
82
+
83
+ def validate_response(self, actual_response: Dict[str, Any],
84
+ expected_response_json: str) -> Dict[str, Any]:
85
+ """Validate API response against expected response."""
86
+ try:
87
+ expected = json.loads(expected_response_json)
88
+
89
+ validation = {
90
+ "structure_valid": True,
91
+ "field_errors": []
92
+ }
93
+
94
+ # Check required fields exist
95
+ required_fields = ["prediction", "confidence", "predicted_label", "model", "mediaType"]
96
+ for field in required_fields:
97
+ if field not in actual_response:
98
+ validation["structure_valid"] = False
99
+ validation["field_errors"].append(f"Missing field: {field}")
100
+
101
+ # Validate field types
102
+ if "confidence" in actual_response:
103
+ if not isinstance(actual_response["confidence"], (int, float)):
104
+ validation["field_errors"].append("confidence must be numeric")
105
+ elif not (0 <= actual_response["confidence"] <= 1):
106
+ validation["field_errors"].append("confidence must be between 0 and 1")
107
+
108
+ if "predicted_label" in actual_response:
109
+ if not isinstance(actual_response["predicted_label"], int):
110
+ validation["field_errors"].append("predicted_label must be integer")
111
+
112
+ return validation
113
+
114
+ except json.JSONDecodeError:
115
+ return {
116
+ "structure_valid": False,
117
+ "field_errors": ["Invalid expected response JSON"]
118
+ }
119
+
120
+ def test_dataset(self, dataset_path: Path, max_samples: Optional[int] = None) -> Dict[str, Any]:
121
+ """Test a single dataset."""
122
+ print(f"📊 Testing dataset: {dataset_path.name}")
123
+
124
+ try:
125
+ df = self.load_dataset(dataset_path)
126
+ if max_samples:
127
+ df = df.head(max_samples)
128
+
129
+ results = {
130
+ "dataset_name": dataset_path.stem,
131
+ "total_samples": len(df),
132
+ "tested_samples": 0,
133
+ "successful_requests": 0,
134
+ "failed_requests": 0,
135
+ "validation_errors": 0,
136
+ "latencies_ms": [],
137
+ "errors": [],
138
+ "category": df['test_category'].iloc[0] if not df.empty else "unknown"
139
+ }
140
+
141
+ for idx, row in df.iterrows():
142
+ print(f" Testing sample {idx + 1}/{len(df)}", end="\r")
143
+
144
+ # Send API request
145
+ api_result = self.send_prediction_request(row['api_request'])
146
+ results["tested_samples"] += 1
147
+
148
+ if api_result["success"]:
149
+ results["successful_requests"] += 1
150
+ results["latencies_ms"].append(api_result["latency_ms"])
151
+
152
+ # Validate response structure
153
+ validation = self.validate_response(
154
+ api_result["response"],
155
+ row['expected_response']
156
+ )
157
+
158
+ if not validation["structure_valid"]:
159
+ results["validation_errors"] += 1
160
+ results["errors"].append({
161
+ "sample_id": row['image_id'],
162
+ "type": "validation_error",
163
+ "details": validation["field_errors"]
164
+ })
165
+
166
+ else:
167
+ results["failed_requests"] += 1
168
+ results["errors"].append({
169
+ "sample_id": row['image_id'],
170
+ "type": "request_failed",
171
+ "status_code": api_result["status_code"],
172
+ "error": api_result["error"]
173
+ })
174
+
175
+ # Calculate statistics
176
+ if results["latencies_ms"]:
177
+ results["avg_latency_ms"] = round(statistics.mean(results["latencies_ms"]), 2)
178
+ results["min_latency_ms"] = round(min(results["latencies_ms"]), 2)
179
+ results["max_latency_ms"] = round(max(results["latencies_ms"]), 2)
180
+ results["median_latency_ms"] = round(statistics.median(results["latencies_ms"]), 2)
181
+ else:
182
+ results.update({
183
+ "avg_latency_ms": None,
184
+ "min_latency_ms": None,
185
+ "max_latency_ms": None,
186
+ "median_latency_ms": None
187
+ })
188
+
189
+ results["success_rate"] = round(
190
+ results["successful_requests"] / results["tested_samples"] * 100, 2
191
+ ) if results["tested_samples"] > 0 else 0
192
+
193
+ print(f"\n ✅ Completed: {results['success_rate']}% success rate")
194
+ return results
195
+
196
+ except Exception as e:
197
+ print(f"\n ❌ Failed to test dataset: {str(e)}")
198
+ return {
199
+ "dataset_name": dataset_path.stem,
200
+ "error": str(e),
201
+ "success_rate": 0
202
+ }
203
+
204
+ def test_all_datasets(self, max_samples_per_dataset: Optional[int] = None,
205
+ category_filter: Optional[str] = None) -> Dict[str, Any]:
206
+ """Test all datasets or filtered by category."""
207
+ if not self.test_api_connection():
208
+ print("❌ API is not accessible. Please start the service first:")
209
+ print(" uvicorn main:app --reload")
210
+ return {"error": "API not accessible"}
211
+
212
+ print(f" Starting dataset testing against {self.endpoint}")
213
+
214
+ parquet_files = list(self.datasets_dir.glob("*.parquet"))
215
+ if not parquet_files:
216
+ print(f"❌ No datasets found in {self.datasets_dir}")
217
+ return {"error": "No datasets found"}
218
+
219
+ if category_filter:
220
+ parquet_files = [f for f in parquet_files if category_filter in f.name]
221
+
222
+ print(f" Found {len(parquet_files)} datasets to test")
223
+
224
+ all_results = []
225
+ start_time = time.time()
226
+
227
+ for dataset_file in parquet_files:
228
+ result = self.test_dataset(dataset_file, max_samples_per_dataset)
229
+ all_results.append(result)
230
+
231
+ end_time = time.time()
232
+ total_time = end_time - start_time
233
+
234
+ summary = self.generate_summary(all_results, total_time)
235
+
236
+ self.save_results(summary, all_results)
237
+
238
+ return summary
239
+
240
+ def generate_summary(self, results: List[Dict[str, Any]], total_time: float) -> Dict[str, Any]:
241
+ """Generate summary of all test results."""
242
+ successful_datasets = [r for r in results if r.get("success_rate", 0) > 0]
243
+ failed_datasets = [r for r in results if r.get("error") or r.get("success_rate", 0) == 0]
244
+
245
+ total_samples = sum(r.get("tested_samples", 0) for r in results)
246
+ total_successful = sum(r.get("successful_requests", 0) for r in results)
247
+ total_failed = sum(r.get("failed_requests", 0) for r in results)
248
+
249
+ all_latencies = []
250
+ for r in results:
251
+ all_latencies.extend(r.get("latencies_ms", []))
252
+
253
+ summary = {
254
+ "test_summary": {
255
+ "total_datasets": len(results),
256
+ "successful_datasets": len(successful_datasets),
257
+ "failed_datasets": len(failed_datasets),
258
+ "total_samples_tested": total_samples,
259
+ "total_successful_requests": total_successful,
260
+ "total_failed_requests": total_failed,
261
+ "overall_success_rate": round(
262
+ total_successful / total_samples * 100, 2
263
+ ) if total_samples > 0 else 0,
264
+ "total_test_time_seconds": round(total_time, 2)
265
+ },
266
+ "performance_metrics": {
267
+ "avg_latency_ms": round(statistics.mean(all_latencies), 2) if all_latencies else None,
268
+ "median_latency_ms": round(statistics.median(all_latencies), 2) if all_latencies else None,
269
+ "min_latency_ms": round(min(all_latencies), 2) if all_latencies else None,
270
+ "max_latency_ms": round(max(all_latencies), 2) if all_latencies else None,
271
+ "requests_per_second": round(
272
+ total_successful / total_time, 2
273
+ ) if total_time > 0 else 0
274
+ },
275
+ "category_breakdown": {},
276
+ "failed_datasets": [r["dataset_name"] for r in failed_datasets]
277
+ }
278
+
279
+ categories = {}
280
+ for result in results:
281
+ category = result.get("category", "unknown")
282
+ if category not in categories:
283
+ categories[category] = {
284
+ "count": 0,
285
+ "success_rates": [],
286
+ "avg_success_rate": 0
287
+ }
288
+ categories[category]["count"] += 1
289
+ categories[category]["success_rates"].append(result.get("success_rate", 0))
290
+
291
+ for category, data in categories.items():
292
+ data["avg_success_rate"] = round(
293
+ statistics.mean(data["success_rates"]), 2
294
+ ) if data["success_rates"] else 0
295
+
296
+ summary["category_breakdown"] = categories
297
+
298
+ return summary
299
+
300
+ def save_results(self, summary: Dict[str, Any], detailed_results: List[Dict[str, Any]]):
301
+ """Save test results to files."""
302
+ results_dir = Path("test_results")
303
+ results_dir.mkdir(exist_ok=True)
304
+
305
+ timestamp = int(time.time())
306
+
307
+ # Save summary
308
+ summary_path = results_dir / f"test_summary_{timestamp}.json"
309
+ with open(summary_path, 'w') as f:
310
+ json.dump(summary, f, indent=2)
311
+
312
+ # Save detailed results
313
+ detailed_path = results_dir / f"test_detailed_{timestamp}.json"
314
+ with open(detailed_path, 'w') as f:
315
+ json.dump(detailed_results, f, indent=2)
316
+
317
+ print(f" Results saved:")
318
+ print(f" Summary: {summary_path}")
319
+ print(f" Details: {detailed_path}")
320
+
321
+ def print_summary(self, summary: Dict[str, Any]):
322
+ """Print test summary to console."""
323
+ print("\n" + "="*60)
324
+ print("🏁 DATASET TESTING SUMMARY")
325
+ print("="*60)
326
+
327
+ ts = summary["test_summary"]
328
+ print(f"Datasets tested: {ts['total_datasets']}")
329
+ print(f"Successful datasets: {ts['successful_datasets']}")
330
+ print(f"Failed datasets: {ts['failed_datasets']}")
331
+ print(f"Total samples: {ts['total_samples_tested']}")
332
+ print(f"Overall success rate: {ts['overall_success_rate']}%")
333
+ print(f"Test duration: {ts['total_test_time_seconds']}s")
334
+
335
+ pm = summary["performance_metrics"]
336
+ if pm["avg_latency_ms"]:
337
+ print(f"\nPerformance:")
338
+ print(f" Avg latency: {pm['avg_latency_ms']}ms")
339
+ print(f" Median latency: {pm['median_latency_ms']}ms")
340
+ print(f" Min latency: {pm['min_latency_ms']}ms")
341
+ print(f" Max latency: {pm['max_latency_ms']}ms")
342
+ print(f" Requests/sec: {pm['requests_per_second']}")
343
+
344
+ print(f"\nCategory breakdown:")
345
+ for category, data in summary["category_breakdown"].items():
346
+ print(f" {category}: {data['count']} datasets, {data['avg_success_rate']}% avg success")
347
+
348
+ if summary["failed_datasets"]:
349
+ print(f"\nFailed datasets: {', '.join(summary['failed_datasets'])}")
350
+
351
+
352
+ def main():
353
+ parser = argparse.ArgumentParser(description="Test PyArrow datasets against ML inference service")
354
+ parser.add_argument("--base-url", default="http://127.0.0.1:8000", help="Base URL of the API")
355
+ parser.add_argument("--datasets-dir", default="test_datasets", help="Directory containing datasets")
356
+ parser.add_argument("--max-samples", type=int, help="Max samples per dataset to test")
357
+ parser.add_argument("--category", help="Filter datasets by category (standard, edge_case, performance, model_comparison)")
358
+ parser.add_argument("--quick", action="store_true", help="Quick test with max 5 samples per dataset")
359
+
360
+ args = parser.parse_args()
361
+
362
+ tester = DatasetTester(args.base_url, args.datasets_dir)
363
+
364
+ max_samples = args.max_samples
365
+ if args.quick:
366
+ max_samples = 5
367
+
368
+ results = tester.test_all_datasets(max_samples, args.category)
369
+
370
+ if "error" not in results:
371
+ tester.print_summary(results)
372
+
373
+ if results["test_summary"]["overall_success_rate"] > 90:
374
+ print("\n🎉 Excellent! API is working great with the datasets!")
375
+ elif results["test_summary"]["overall_success_rate"] > 70:
376
+ print("\n👍 Good! API works well, minor issues detected.")
377
+ else:
378
+ print("\n⚠️ Warning: Several issues detected. Check the detailed results.")
379
+
380
+
381
+ if __name__ == "__main__":
382
+ main()
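
The tester above can also be driven programmatically instead of through `argparse`. A minimal sketch, using only the `DatasetTester` constructor and methods defined in this file (the two positional constructor arguments mirror the call in `main()`; run it from the same module):

```python
# Programmatic equivalent of:
#   --base-url http://127.0.0.1:8000 --datasets-dir test_datasets --quick --category standard
tester = DatasetTester("http://127.0.0.1:8000", "test_datasets")
summary = tester.test_all_datasets(max_samples_per_dataset=5, category_filter="standard")
if "error" not in summary:
    tester.print_summary(summary)
```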
submit.sh ADDED
@@ -0,0 +1,2 @@
1
+ export DYFF_API_TOKEN="<YOUR_DYFF_API_TOKEN>"  # never commit a live token; inject it from the environment or a secrets manager
2
+ venv/bin/python challenge-cli.py submit --team 3ae337c1660d4b738c443c2c37fc1bca --account Account/7b1efb9c00324cdc9a412cc8bd8f70e2 --task 567316f19d37490b97aa647e7017ef44 --service sida
test_main.http ADDED
@@ -0,0 +1,12 @@
1
+ # Test Prediction Endpoint
2
+ # Works with any model configured at startup (default: ResNet-18)
3
+
4
+ POST http://127.0.0.1:8000/predict
5
+ Content-Type: application/json
6
+
7
+ {
8
+ "image": {
9
+ "mediaType": "image/jpeg",
10
+ "data": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wCEAAkGBxMTEhUTExIWFhUVFhoYGBgYGBgYGBgXGBUYFxUYGBcYHSggGBolGxUXITEhJSorLi4uFx8zODMtNygtLisBCgoKDg0OGxAQGy0mHyUtLS0tMjAtLS8tLS0tLy0tLS01LS8vLS0tLS0tLS0tLS0tLS0tLS0tLS0tLS0tLS0tLf/AABEIAKoBKQMBIgACEQEDEQH/xAAbAAACAwEBAQAAAAAAAAAAAAAEBQIDBgEAB//EAEMQAAECAwYDBQUGBAUDBQAAAAECEQADIQQFEjFBUWFxkRMigaGxMlLB0fAGFBVCkuFigqLxI1Nyk9IHFkMzsrPC4v/EABkBAAMBAQEAAAAAAAAAAAAAAAECAwQABf/EAC4RAAIBAwMEAAUEAwEBAAAAAAABAgMRIRIxQQQTIlEUYXGRsTJCofAjgdHhBf/aAAwDAQACEQMRAD8Aw5s1A0U9kNocCynNNQ0DLs6tmEC63NOiSlpYJ2HCPJlNpDGXZ2rBgsQUlwKjOJdzNjQ6D03QtkyH08oum2cJGQi8yCPDrE2ejRotg89zd7oEkqAIoOkNhaAQE4RXWkLTZ1Pk4i2y5784TQi3flawzTYhsOkXy7MNh0gizFwKNBQlxS5ltkETZhsOgi5NmGw6CCBLbWCESHaBc7SwEWQbDpHZlhSQRQcWhl2MCrtKUKGJgPFx5QssrcvStGWUmI1XEp3IcHhExcWQYeNI1UmcFB0mhyjxDcYwOjWk9z2Y9bQgtjMKu6UipFdqExXaJiQHCR4t8Ic2iypxYi/wim1XbiY67fOO+GUWtTuxo9c6ibirL2ZS3SysvhFdhACrEraPoCLAQ1AeAELrXdgUXAz+sodudNaUkCn2K0nJt/7MYLPwiwWfhGgXdvARxFgHDpGqnGTyzLWnTV1ERfd+HpEhZeHpD37kNh0ixNjGw6RrUTyZVBALJw9IvRd/DzEPkWQbDpF6LKNh0htiblcRy7u/h80xb+Hjb0h6iQNh0iX3bl0g6iTiIfuHD0jyrAPd9Ieqs/LpEV2fl0jtQNJnV2Lh6RUbHXL0jQLkDh0ildmrp0g3BpEa7INvSK1WYbekPV2Xl0ildl5dI64bCc2UbekV/dht6Q6VZqftEBZuHlHINhKqzDaBewG0aM2T6aAfu300cwJASJymppB9ktYyUIVS1NR4mCTlGNHozWTUWaUkhwxHCDJNnABDZ8PLlGPkWlSS4Lcobyr6mEhyB4CFlSTGhXqQ5uhjOlJByPygKcEDIkHQ/N4eyLRKXXEAfntHLTIkqVgWCDnl8YKbW5F2bwI0CoCWL1NKdXhpJugZ5P6xcLtTKLoDvXEch4HOI2K/CtfZplYqs4LPxY5R177AsRstjmhR71DkaMPAwwtM5EpJMw5DqWp1hhMUhIYqCSqgBLExmL+uObNI7N1M7qJoBpXrAGVjky/5KsgrPLgc+BiVqtgWlAQCSMhw/M+1YrsH2KmYhjUGIdx5jnGl+zlx9kCpQ7xpvT94LaCUXTJUZbkkYXehd826bQDOXMWzoolRAUB7QNXz+Gkay19nKQ61JQjV9eAGpppFCLxshT/6yGpQltaUPEQEwGYu+7FKmv2ndDsRlUsQDGuFkYeEQVZbNZpa7QAMOalJ7z4jn56QnsP2yQcRmIIr3QPdrmSz6c+Edl7HYG8y7Bm3ebWK50pMtJVMUBuTlySIJkX5JKEqMxCSfyk1o/TLWMFfX2hmTz3gAgZJGQrmXzPGDGOQuTasaFd+2erEnCMsiasyd4BvO14UCYO6lTcSX32yeMcu0h8oqm2t+W0NpQY6th9OvNSi9AOIL+sdl20anyMZ2Vajx6xaZyuPUw6kkh+1J8D025O/kYnJtD7dDCIKL6+cMLGo7mGuuAdvhodygTF7GF8qcRv1ghNp4HTWDcjKkHS0RahML/vJaj668YNlzOefwjrk3AsVLMeMqJ9rTKOLUdjHC2BJkqIGRWClJOxiKJKn16GGTFaKOxilciGIknj0Me+7E6HpHXOsKjIeJCzfTw6s93l2LjiU5UgwWRCM1IOlQB6mElWhDdl6fS1KmyEKbrUfy67iFv4cv3fMfOH96XskA4UgUB7qwPzcBGc/Fuf6/wBozrqZS2R6a/8AmU0vJu/9+Rip1oxF2YxZKKgHYtvEJNNQecNJgRMQEpUArV6B9IDnosrYJxoKtd6s+vYImYOESC6wFNQpJYxwTYsmmYpQknZmgu219mXfw02qIMt17qUkIQmvvByrOgB0jNSpu5g2TawmteBByjrISzDECaokl6ZkuW59YayLaZSQqWAfeWdTqG1MJrbfJmJSmjJq7VLgO/SAjaTA3Gs0M7dbitZWrM+HlBtzXsqScSTXIg1BHEPWM+pZOcSwGGxawLPc+m3X9p5cxSUKwSw3tElnbdxhru8aeStKgClSSDkQQX3bePiMte8N5N9qlpTgLKANSKir0MSlT9Dxb5Dft7eOO04KYZQwgjUlioHkaeBjM9qSWakSt9sVNWVqqo5nc5aQPKVWHUtKsOqTkzUSr3TLsi7KHV2jKUrEQyu6WAaoZLHfwqgM1uMUrcxxElR0JibqpGiHSSfBbOtyiGyEU9sTSLU2RSvymCLFZ5aS6yS2QTqeY0iU+oSWDVT6J3ysFRsC2ds46LuUdINVe89+7LASKDuOQNKkRfOtE7D3+7i91OXAloxSrVebHo06FFbJipFiINRBiLKT/eJ2CzlRNSaah/jDuXYhuOkOq9lkWVBXwKE2SuXnBdnsphqixp3T0EXJsoH5k+Xzh41pSZJ04RQDKs5iZlQTMkbK6f3ir7qr3jmNDvzjTGczLKNNlQkq4Qxs0vff4RXKs5bM5H1hnIsZfPXbhFYzMdeMUsFfYCLAkQbLsf00W/cPrDDazHouL0JEXolwaiyAbdGiJnoScknxGsTl1MYlqfSTnsQRZNx5x6b2aNxVjQnR4WW+9yMnyOS29BCW1XiVHNQqD/6hOjRDu1Kn6cI9On0NOC8x/bb1GE10907xnbdeZJoR7Xu8IEnTFHVXUwJMB458Y6nSSeTa3pXieWomhbKF/ZQahJOp84F7NW584u4+iDqW3YjTKpoYjMs6hXfjB6btWQCx5/3ERN3L3HWB3EjzFQk9mAAGIrlQWuxrBZuhEVrlqGhju4hn08wTLSJpMWqA1cHqIj2UdrO7ByJBcR7OJYYOsXsssTOIiQmmKkiJEQXMMaBaFcYJRZphySS8UyJpTklJ5gxbJmzS+ENi2f5tEJ1JcG6l09Pm7+gQq5ptGzOlAY6izS5ZZQmLV7oSBVst4gi7p6iF1cUCipj4EmGdiuxQqucl3qy3roMqHkYyTrNLyl9jZToxviFimRZEGvZTArRD18SoU84cSrKEh3QkalQdjqyiO8fDSIi0KlDNCf8AUsk8SzH1iKbxQ7hBJOaiXHIA6dOUY5dyptsa04wLxbEthBKyNQwB8AwIiuyS5jv93wbLIDB9e9r4GDEXkGGGnBLJPj3fjEzeS1ZqNN8Pq0TjTnwvuLKo
Tk2KYoP95UAMwC/9TR1UgJDqWFPus1/lLwDNnKUXxEeMUT0Pq5EOumb3YnczgITbUk90N/K3wEW947QHd8glRzy2fWH8mzcPKNcIQhsQqyk3lgaLOdvOCJdm4ecGosx28oJlWQ/Qi6qsyShFAMuxjbzgmXZRtqNYPlWM/SYIFk+sMPrZnlJcC+XZBtvrxg+VZztrvF8uQw66cYqtiyAQHdtOcF1VFXIqDquxYCE5wPaLySMz/SreFk4zN1/1QHaJUz+PL+LeM0qjmz0aXSQgF2m9E+9v+Uwtn20nUdIiqyrOi/0mJpsCvdV+kw0IpF5aUtwNcpKs/jFYsiB9GGX4Ws+8P5TF8q4Vn85/QfnGmM8Gacobt3FZkp+nipVkB0840ku5CKu/8n7wSi6hsP0QVNIk6vzMim7VaDzEL/uC9vMR9LTZEj8g/SPlCb7un3R0EHvEXM+c2FUxGEpOTEO8Ml30A5nSpZIbJIdlKbP6yhR9mb8lJkNObuszt0AMJ/tPeAVaFdmRgDANkW1Fd3rGWOqU3Frbk1VZ0lTUl9jdzZkgueyALcT5GFc/CGZmJAAIZySwoDErivaTNnKRMKQOzSoKUcIBYlVSc6geEZ/7Q3oDaFJlkYEqABFRTDUF9wYWDk5abBmqUY6kzS9lKI70kPuFK9IpNyy1eyseX7kQTJnSVzxKdNZSVgghsRfECXzZojbkIRapVncPMBL7e5XckGnKFVbNvlcq6cd7823FFoukgsK/XKKVXadx6RprJYkrmzJSZgJlpSS71xO7DgwrxiJsQ7fsKYuzx5aO0N8SlgXsX5RlzYiDxi1CFJqAAd2BPnGpNzHSvIExSu7iPytzT84D6mEi8Ona5M0EKNMzwSPhBNnBBZSlJ4YRDU2Ze5+uERFmO3n+0c5pjxjpe7KJKUEe1lulPkKxcpKSw7x8UgfpaLUWfkD4GL5Vn5xJ6dyuqTB+yQP/ABjm7egiXZqOSQkDb94Yy7EOJ8IPs9l2ST1+BiMq0Yhs+RUmQTqfB/nBybtQ3t12NPM5Q0lyfeSG4GvDNUGy5LCiW8B/yNYzS6l8E5zSE9kupGalYtggYusFGyyxVsPMEK/pL9YbGS4qtXgUt5GKDYEh3HUv6RKVST3ZFVE3uLZIQSQkF9z8yTDGVLLRGXISMkjwEFo5RtoW04Fqy9FkmUX8YNlSoGlGLu1LUeNaZhmmwoIimdaEj83kflAmFfvq6n5xEWRRzV1eG1ehVTXLJqtg38jFPbklw0Eou/iOkFpsg2HSBbljOcVhC/t1cI92i9h9eMMvu42HQQDfN4S7NLC1MXmIQ1PzEP0SSekMvkS1E5Sl8IvQkwPeF5ypIlqUQ01YQDTUEvypCu134E3nJs7gJMhb/wCpRlqS+mST9GjJNitmiCYHm25KVhBLEh/MJHFyTAH2kv1EiWFBSSoqSAMQ94O7HJqPlWpEYL/qNfahaWlKDJCA4OIEM5Aagqa8oaMXIXk+qMYXXjeqZK5SVFhMKn4BKCX6gDxA1hJcX2pSbJLUopxBABGMAummR5fWUfO/t3fhnT3SaIDJdlMXeh2y8RBhTbYXg+x2S29oVhLdxQDuC7oSsHh7XlCx1RlfsF9qUgTjOVVRCwVKAySElLqNS4cAZAx7/veV7p/XHODTsdufHQmOhMNU3cWyiYu07RbWiSpsUlEdww3F2K2iYupW0LrQ/bYpSohi8XzLZMMxMzEcSQADSjUEMfwlW0d/CFbQrlG9x1GaRVdd8zZE1U1Ku8oMpw4IO7xcm95vaqnEurBhJ4FOH0iQudW0XJuhexictDyVjrWBlL+2K2kBSB/hPiId19wpS9eJeHU37YyGThlKqAVB6ineAJO7eEZhNzLP5TBAuJZ0jPKjSfBeFWquTbWW2WWYpKUzgSpRSBlUAHUcQHg+TdiFpC0EKSqoLpbyMYBNxL2MHWe6JoyKhpQkZVHmYzS6dftkaFXmbH8HGxjv4YBv0hbd8+1oThCiRhYYg7cQTWHVkvif/wCSUk8g2/PhEJUqvsPfnwVy7E2QEEy7Fw9YZ2S3IXmhSNnDv0hkiWDlEZQnyRn1UluhRIsja+vxg6VZ+MHJk8IsTL4QY0W9zJOvcDMvjFE2zgw4RLjk2VSkafg5OGtEo17MQfdhFiZEHrkxDA0dTi0rGjvXBhKixMqIzrUlOhPKF1pvZY9mWdc66UyjTGEmFKUhuhERtM9MtClrUAlIck6CMpeFstMwNiKB/CCH3rnCKdd0wu5WXDH2qh3Y71DxeNNndp8s2ls+1NklKwqnAl2LAlqO5YZct4Ro/wCoEvtS6T2dBSp7ruQMquOkZw3GrY/pMQmXGpsj+gxZU0DtpBdv/wCoK+3lTUSwyJakqSciVEEts2FPQxl74+0E+0Yu0UGVM7RgG72EIDcGEMplyK2P6DA825lNkf8AbMVUUhHH5C+9L8nTky0TJjpll0hgGO7gOTziidfE1U9M8realmUQNEhIozZAdILXc6joofyGKlXSt/ZV+gxRRQjT9A943zNm4cZBYN7IELpk/l0hv+EK2V+gxFdyK91X+2YooxJtTFcq2KbDSvCK5q4ci5le6r/bMVrudb5K/QYZJAakJRNPCKsf00O1XKvZX+2YD/DFbK/QYNkTambtCpDJ7qMhmeHOLEqle4nqP+UZWUSwPZpelRXLh9ZRei1N/wCM/pEeG1LhnsJx5RqCuSPyJ+vGJCdJ9xPUfOMwbfoEp/mGrR4Ty1EpfavzELpl7DePCNaJ0r/LH14x42mV/lj9JMZNE5fuS+RB+cUrt6sTdjLU2oA+Md25eznKK4NmLxkDNKRzYepif4vZx+VHKh9IxybSvMJlo/lST0Ecl3wx709XhLQPWC6UhdcTdS71le6j64GLk3vK2SeQf0EYtN4YgCkzCd+78IIVaJgFZiv6Q3nEnB+/5HsnsjYpvhGYQP0EeqYtTfKRmEeX/GMEu+QAxExR5GsTk3yMlBXiTlAdOYPA3qb+Gw8A/oIuRfnL+gesfPZN8SnoS2tEv5xem2k1lomKGrhKa0Pz6wHCa5OtTexv/wDuAjUf0x0faBW48SkRh02yaR7BT+k+hETlXo2eNxsk/ExKWtc/yN2qfo3Qv4+95hvKJi+1ak/XKPn0y8VBWJMuYRqWB8qfQgk24kPgmDoPKJt1fZ3Yp+jefjR1V6wIq/Ve8esY+VbD+VJB4gPEROXWp6o+CTGno9cpNSY0enprg1y75V73nFS73Vv6RkV2guAZiknbAFcq0aIptuYxTCRniCUjoTHo9tDaaadrfg1/4wrL5RA3qdvL9ox8y8xlWmZdB8gaxGVeKQCSZhzoEHTZtIdUzv8AEvRr/wAV3w+Xyjn4oNh5RlpFtJGLs5gG6gB5Kyga0W8g1mtwYK8wqCoXwc+2lf8Av8mu/FkuzDoflEF3mk0+BjLpthNHWTuHSPi0cNqdTEKp9VJEFROtD0jSKtaYqXORxhCbakaN/NHUWknceMGzKRVO+y+w77RHHzjgXLfXzhP95zd6cYslz3rUeMdlD6aT4GRmSuPnEwqUd/OFSpvPrHE
zq5keMdZh0U/SGyjK2PnER2Wx84RqtR3V+qJCeWd1fqMdZ+zlGk+Byex4+cLOzlbHziEy0NqfBULfvX+r9UMk/Ys6dL0ZIWxQyrTeLBal69IhKkJcd6jfD5xebRhYhIfjyMT5wjw0na7ZMXiRQp+ucdC5inKEkDkdn+ucULt4OaU8+r7xZZrSTQZH6Hp5wko2zYeMtWNR4KmULqqPrXlBGJRGY55Goi6RIQc1e0aas523j0+6kl8KwSOeVM+HzhZVE3kpGjJLx/ILKQoHEC/jBcu0rz7hBOvjxgKfd6gHxUA9Sz+kUJJQWJB8Y5pSyBNxxZoZfeJhOaQ21ItwrA9tLcVE604QnmzgajOIDGugVTjlw8I7tpgda2Ms0EuclJZS3JDgvwGjHcxNFsSNH4kPV4SSrsmPmH5wZOu9ftEuNSCOGY5vCypxezKQqSX7bDgTyWKVpHBgPNmip11xTta5noBxhQuZLQQS6nBOeVWgj8aQAwQw2/fwiPbfCLdyPLt/sZWeyqaigdWdiYsCVADvl6ZVH75QpTfcshsJH8Tl/XhHpiisYkKpsebFzzIhHCXI6lHdMZTLTPTkkq8QGiUtU13Vm+poOJ4QhwzCogE0+vjFk2yziD3nEF0rehVVzyP13tLTmoE/wgepzio3lJPsqUkqo4YtzDRnvw3+J/D1ia7qAFZgJ2HzMaOkpU4yeRZ1qlsRG6J05NUWhKgdHIpuQdeFYOlqnBGIS0zNyApRrpUcNIVWW85VnHcl9/UqL6vTYwDO+1NoJJxnhwrG3S5vCF70Ka8pP82+5ok2tnM1SZb/AMAfhrtwgebbpJHcnYXZ2SEqPiPCEh+0S1pwzQFjiPj9ZCOJNmWaOgmjZDqfrKD2tOWL8UpYg0/rdP8A4OVSEKS6lTa/xAvxr9VidkkywD2ZNNVAO+wJ56CFGBKa41KDigOgj1qvJRDSkYQKd0acTvxhcvCYdcF5SWfux5PmMaHF/N8B9VgY29CQ6mfQMT8TCGVZZkwuVgEnJ6wzlXM5AKiQMz40A2o/SGUYrdgdarP9MfuEm1ylVDOdMuUeUvQM+0DruZOeJhpn8vSLkXaEd5CqjfOtI7VBcjqNV7xS+gR2E1nUgAAO7jIByWfaB1WutMJ8DHBeExL4ypSdRoUjMeIpFtmvCQst2AHgnSC9Vr2v9Du5G9k7P5/+BUjCpJ38YrmSpnsoQC31qeMBT7UUTJaUkgLOlAK69YYJtZSyiol9jX6pAeqI0ZwnvhrewFNVRlsFn2QHqPPjFaitKQWDGg+gYYXdYwpJmKUlWEsxqchk/OOTpImKKQQkJL8NqCG1q9kBwla/2BkzWFWbkYG+8J3HQxJSq5FtdoGxJ2HlDYJOvISWezlTMdhrR96cDF05JoFEUptlE7HbUppgHluD84vtlmxgYSBmaDyprnGVzle1sGVUo6bxd2BBKRo+tdo8be1EjyiKLGpiS4wjIg1+UVT7IUkBi5Zs6uzN1jrJiuU4rCsdm21SjyyrxixFrmJNaOIjJksRiHWj8oItWApISAKv5Gg+vhDtRatYVOd9Vy+TOWujUyD70o+WsGSruCmVMoQ3doHFCC7sz7NCpFrCEsKceORMQn3ip3JzGnn5wqptbFnVj+7JoBckpYSy2wkBVQp2YEUGdM4ptsrsckhSFZHNw+41YfTwkk3gt6Nrm2uecGKvHCkJUHoR7RwnUHiQ/nCSgykasLXSsyybbCr2UgNs+2bxFCphDVKX5OeeekUotYNEpDk15Elhs0XLmzBUghyQ+jUSacH84S1sI7VfLZZ+HKUKjL5hxuzGIS7rJFVMHbwicq1E0xF2y3ap8aeW9IhNvA5ca+EC8w6aW4bZ7slpHeLnUCtHz+EFpWEuE5El/HhzIhYm1qUXBAo/pX+3GCpE6gJfXhlmD4t55xKSk9y8HFYiW/iEsA/4ZDsDtw8jFsudiUCkNTfcf2iaES3cgU5bvkenXjFU+1Swe6APD639PCeHhJlVfdtWCLQZYAxHpxETVd0hdQshxqxY6NTjGVtVqCifKrx6VeC05aZMNeY5Rs6Sg75Znq9VBYawN7bcqkg4ZgUSQANcnJPllvCedY1DyLa1D5Qys98Es4Ph8R4mLkIQsE4iC54B3GVeJjQ5zg8meVKnVzTM+hwQAHPJz0hnc9nmziVAJCQQTiAD1zDhj46xcmRKSpy9D4M9c86U5R6dfobClDCgYACgNBxoWhu7qdooFKgoZqysvS5GiblkIUP8U0oRpic1yJoNOAiM+xWcuoKIO2LfLTb165mfbFhzoTvWh8opN4E6xVU78jvqqEXp0DW22lcs0A4ED64QEu2KDkEseNOXnHFXiFe3VhFlkCVqSkJcbZk5A8svCOUVHdGepUc5eEvodkWxY1LaudOUMLJayaACKvuiW9jMBjrmwPB9orl2RUouVu22r0zeFThIvDvQtyvwjTS7vnLQQlDhSWPeSKKGjmOTfs5JlygqapaV4mUAQQASWySdG11hR/3UpAwjHtRW1IHmfaRUwkKMwjOqqUECNOpf0is+o6d/N/MutFmmoUBNThSrKoJKd6E8IslTbOgsFnGMwQqgzzwttHDOXNBJKlEBkuSojgIW4CJhBQokCtC+mcPvuRk3Td4rD5ZdZDLShS0+0D3XBbxpxiSr3Qw7wxfmGFVICNkmZ99AGaWUMXFvrKHNguATEgqKUOHdSc/Es8GUoJXYlONaUrQR6xze1BlZlRyFMmVmaD2Yj+Bq9w/qT84Z3YuTJFUoUtyQruggEAMCa79Y7+Mo2H60xmdSSfisHofDwkk5tXMMZRSyq5Aqpk5I+n1eG9jVTF7oxB+B2gWRNbulRKSzk1O9a11iAnglTUchqlhvU6ROaMFJ6M+wiTOUrEFOS702Bb5w/tEhBUkqBKqTKHYJc58ATzfWF1lnABSZgTLIYggOXBdidDTxjiLwQoDchjt7Ayfd68hEZXbujVCyXk7iu87PMM3DTvKITsyS2goKQBOCgSkioLEEEeuUav76hRlBVGck8SAH5sBrp4Qjt6DjChValFRSBkfaSATU00rpWtLUqnBkr0UvJPkWoRiqSenX4dYNRZsSHb2AScy44cHIHMvwjsyUQvHgo7qGgr3gT4n9od2BEsDCafkUl83Se9UuAVFJ4FNKUg1KllcWlSu7CuVYFLQlgAaCgFQaBXGjvqOzVR6RdZrmUtQQ7AYiXYNhzetTkW4HaGlltQT7RGVDscTKqfHziSbWgFIo63LijBTBuqTlnEe5LNkalQji7BLHYhhwLSwYEkb0BPIpy3Lcoum5AEvVzpnQmn+kRG8pxKxhd1JGvCtcsh4Qmnz1FVSK6U+HOFScsjNxhgYlA2qfKrv5esD2yUCXAOp6s7vxc8o7ZrUAxI8csjWvjHJtrBDDQN4vsKcPCOV0xZOLRTMThLYvDXx216cY9KtZq78urvxf4xZZ5Yd3ajeH03SLLRJDLpn5OXo2XtK6jaHvF4FSksoGFrKqO1R1Km+MCTJpdnrBapTN3agnU+Bo20CTZRBA+q/Q6Q6UeCMnPkqJjqpxjix4/W
vSPJlkkAaxoo7mare2Cci2KSXBZtdoMTeaiGJfy0aIi7y4QWL1JTWgr3S3u8IlarsSlScKlFL1cVAfccI0OzWToqrDY7KtAOYcnKpJenX94MRdoUArIsOAHmXf45RXZ7JhwlOpJdy9DRxpl/UXjgvc6ihJ3bfXPPzjJJO/gaYuKX+UBn2JYJdyBmQ5GcVCzFRPdbNtuXR4dIvVLEKYuK8/r1ipFqST3aZM5bgz+MOqk/QjoU2/1AlnulZFEkmn9SgB5ehhhd13TETC7AHCHerLq4fNtYKss8JV3gcScLDINuSztUnMZcostVuSSWZ3/YGgzbLgIDrTl42Lw6SnHyvsUXhZhgxOcRS+FxzAIGbJCiSNtc4RzZ60ird4AimjlvSDbVaaqc1Hu7voYoBSUlzWgFBo3TKKU04rJOtPVK0GL0uS+p+jDSw3aVEFSS3Q8Dyz6RRZZIWoUYFQDvUOWJ40rpGtu6XLSgSu0ol++SHLknydvCDVqaVgPSdL3H5bC24LFaSsGVgYLq5rwjXSrLJCiZ4Pbkd/CVYdGZqeyEwDfloTKCClQ9kksWybMiMlb78UpRIL5fnJ0EQSnVyjdKVLpvFu5qL6mCWwmUUoOnWj6tFUgTbQgIl4TgD1pTKI3TasSSpYxYTmouwYanIRGdakhalJUzn8pb00gW4NFv33w/uZ+8rCsTRQNgcsa0xQD2Y2PnGvvK9pakKUMGP2QAatu4HE04Rlfvi/8xX14xpgnJGCtCNOWHe/99i3EpgSaEZcK6RYhWEeNcsiND4QPNNTz+EMrMHUp69xH/sTEbXdjCgNdoJNSTRn2bLPlEEWgt0PL6YdIEUYOQkfdSpqicA+rYCWfZ4GlIXW2FWN1lKUkBRoHoK5V0OUXSAQtLkgjvAlmYZau+Ib5VgWznuA6gljqO8pq+A6RdOLyUE54lB9WxKLPs5MTaNMHdZDZs0hSiwepoRV1Emnj8I9aZmqQA4pq5FD4/tCtJ/xE/6k+ZDxK8Syqbn1ELoyP3LJs4u0u7u70+PwiCrSWSH9kU2qf3MAg/XhE5Jr4fAxTSjN3Gx7Z5y1oJBHcSqhNcIGJTJ1oGhbMUSSczn8fSIqWQlwSCCGOooTBt9qKbQSCQTLlmlKmUhz4uYRKzKSldZKl1YJGRr575e15cYrlDekUoUWVXb0Mdeo8PhDWFvyGpUd6joRrF8mYQCWOQ6d4eqRlxgeyLLgOWKFU0owFIqCi45/J4RxuUU7ZCp88sxDEE8GL1p4+cDTpgOv1lEbQs4U1NcPkFgeQA8BAhMGMRJ1Gy4qg6zz8gQzBvJn51MKgcuUEyVFs9B6xWCyLGWR3MWXJI2y8NB9ZwvmW04lMaORtQu0CS5hfM5bxeMpnAn/AOQRWK9jVZ6l44PC3KZnozecVTl4hRgBwH99YGRpyiSDn4xRoya28MkpBD50oXzHAvrQ+ceQ43Zq19Yutij2iq5rrxq9YLuI95R1CFkHYhmI2NTE5Ssrjwp6pKIQZ60DEpsS0guCxwmo7p56CrGphbNtRLkqJJ0isrJCiST3t/4jA69INNWY1WtKSVngNVKfvPSjDicmD0q9Yts1kVRbOAVUJbIZirn+0L5Ki+e0aa7pYJlAgEFC3BDg11GsNUk4op01ONV5+X5RxExGMKwBqMAGy4DlrA9pvIAkhwOQp4QAVFl10/8AsRCu0KL5xOMFctU6uUY+KsO7bfRmYQtSlJAZmTltSBOyC1dxgCHY5gBhVuMKwo7w5sZaUDr/APqKtaY4M8JutPzGFinTUy1nH3H7wDVoOEBTr2Dkd7oIIvlRcV0+MKJWU46hNDqKjKFgk1qZpqzkpKnF7f8AGyxdpcUNYH+8GBUGIxbY82VSUnk//9k="
11
+ }
12
+ }
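
The same request can be sent outside an IDE HTTP client. A hedged Python equivalent, assuming the service is running locally on port 8000 and `cat.jpg` is any JPEG on disk (both are illustrative; `requests` is assumed to be installed):

```python
# Python equivalent of the .http request above.
import base64
import requests

with open("cat.jpg", "rb") as f:  # hypothetical sample image
    payload = {
        "image": {
            "mediaType": "image/jpeg",
            "data": base64.b64encode(f.read()).decode("ascii"),
        }
    }

resp = requests.post("http://127.0.0.1:8000/predict", json=payload, timeout=30)
resp.raise_for_status()
print(resp.json())
```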
upload_model.py ADDED
@@ -0,0 +1,93 @@
1
+ # SPDX-FileCopyrightText: 2025 UL Research Institutes
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import sys
5
+ import time
6
+ from pathlib import Path
7
+
8
+ import click
9
+
10
+ from dyff.client import Client
11
+ from dyff.schema.platform import *
12
+ from dyff.schema.requests import *
13
+
14
+ from app.api.models import PredictionResponse
15
+
16
+ # ----------------------------------------------------------------------------
17
+
18
+ WORKDIR = Path(__file__).resolve().parent
19
+
20
+
21
+ @click.command()
22
+ @click.option(
23
+ "--account",
24
+ type=str,
25
+ required=True,
26
+ help="Your account ID",
27
+ )
28
+ @click.option(
29
+ "--name",
30
+ type=str,
31
+ required=True,
32
+ help="The name of your detector model. For display and querying purposes only.",
33
+ )
34
+ @click.option(
35
+ "--image",
36
+ type=str,
37
+ required=True,
38
+ help="The Docker image to upload. Must exist in your local Docker deamon.",
39
+ )
40
+ @click.option(
41
+ "--endpoint",
42
+ type=str,
43
+ default="predict",
44
+ help="The endpoint to call on your model to make a prediction.",
45
+ )
46
+ def main(account: str, name: str, image: str, endpoint: str) -> None:
47
+ dyffapi = Client()
48
+
49
+ # You can set these to a known ID to skip that step
50
+ artifact_id = None
51
+ service_id = None
52
+
53
+ # Upload the image
54
+ if artifact_id is None:
55
+ # Create an Artifact resource
56
+ artifact = dyffapi.artifacts.create(ArtifactCreateRequest(account=account))
57
+ click.echo(f"artifact_id = \"{artifact.id}\"")
58
+ time.sleep(5)
59
+ # Push the image from the local Docker daemon
60
+ dyffapi.artifacts.push(artifact, source=f"docker-daemon:{image}")
61
+ time.sleep(5)
62
+ # Indicate that we're done pushing
63
+ dyffapi.artifacts.finalize(artifact.id)
64
+ else:
65
+ artifact = dyffapi.artifacts.get(artifact_id)
66
+ assert artifact is not None
67
+
68
+ # Create a runnable InferenceService
69
+ if service_id is None:
70
+ # Don't change this
71
+ service_request = InferenceServiceCreateRequest(
72
+ account=account,
73
+ name=name,
74
+ model=None,
75
+ runner=InferenceServiceRunner(
76
+ kind=InferenceServiceRunnerKind.CONTAINER,
77
+ imageRef=EntityIdentifier.of(artifact),
78
+ resources=ModelResources(),
79
+ ),
80
+ interface=InferenceInterface(
81
+ endpoint=endpoint,
82
+ outputSchema=DataSchema.make_output_schema(PredictionResponse),
83
+ ),
84
+ )
85
+ service = dyffapi.inferenceservices.create(service_request)
86
+ click.echo(f"service_id = \"{service.id}\"")
87
+ else:
88
+ service = dyffapi.inferenceservices.get(service_id)
89
+ assert service is not None
90
+
91
+
92
+ if __name__ == "__main__":
93
+ main()
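
A typical invocation, with placeholder values: `python upload_model.py --account <ACCOUNT_ID> --name my-detector --image example-submission:latest`. The image tag here is illustrative and must already exist in the local Docker daemon, and `DYFF_API_TOKEN` should be set in the environment before `Client()` is constructed, which is what the accompanying shell scripts do before calling the CLI.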
upload_submission.sh ADDED
@@ -0,0 +1,5 @@
1
+ export DOCKER_HOST=$DOCKER_HOST  # ensure DOCKER_HOST is exported to child processes
2
+ echo DOCKER_HOST is $DOCKER_HOST
3
+ export DYFF_API_TOKEN="<YOUR_DYFF_API_TOKEN>"  # never commit a live token; inject it from the environment or a secrets manager
4
+ /media/NAS/USERS/shahid/WACV/example-submission/venv/bin/python challenge-cli.py upload-submission --name "dino" --account 7b1efb9c00324cdc9a412cc8bd8f70e2 --image safe-challenge-2025/example-submission:latest
5
+ #
utils/My_version.py ADDED
@@ -0,0 +1,207 @@
1
+ # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
2
+ # Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA').
3
+ #
4
+ # All rights reserved.
5
+ # This work should only be used for nonprofit purposes.
6
+ #
7
+ # By downloading and/or using any of these files, you implicitly agree to all the
8
+ # terms of the license, as specified in the document LICENSE.txt
9
+ # (included in this package) and online at
10
+ # http://www.grip.unina.it/download/LICENSE_OPEN.txt
11
+
12
+ """
13
+ Created in September 2022
14
+ @author: fabrizio.guillaro
15
+ """
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import functional as F
20
+
21
+
22
+
23
+ class CrossEntropy(nn.Module):
24
+ def __init__(self, ignore_label=-1, weight=None):
25
+ super(CrossEntropy, self).__init__()
26
+ self.ignore_label = ignore_label
27
+ self.criterion = nn.CrossEntropyLoss(weight=weight,
28
+ ignore_index=ignore_label)
29
+
30
+ def forward(self, score, target):
31
+ ph, pw = score.size(2), score.size(3)
32
+ h, w = target.size(1), target.size(2)
33
+ if ph != h or pw != w:
34
+ score = F.interpolate(
35
+ input=score, size=(h, w), mode='bilinear', align_corners=False)
36
+
37
+ loss = self.criterion(score, target)
38
+ return loss
39
+
40
+
41
+
42
+ class DiceLoss(nn.Module):
43
+ def __init__(self, ignore_label=-1, smooth=1, exponent=2): #because padding adds -1 to the targets
44
+ super(DiceLoss, self).__init__()
45
+ self.ignore_index = ignore_label
46
+ self.smooth = smooth
47
+ self.exponent = exponent
48
+
49
+ def dice_loss(self, pred, target, valid_mask):
50
+ assert pred.shape[0] == target.shape[0]
51
+ total_loss = 0
52
+ num_classes = pred.shape[1]
53
+ for i in range(num_classes):
54
+ if i != self.ignore_index:
55
+ dice_loss = self.binary_dice_loss(
56
+ pred[:, i],
57
+ target[..., i],
58
+ valid_mask=valid_mask,)
59
+ total_loss += dice_loss
60
+ return total_loss / num_classes
61
+
62
+ def binary_dice_loss(self, pred, target, valid_mask):
63
+ assert pred.shape[0] == target.shape[0]
64
+ pred = pred.reshape(pred.shape[0], -1)
65
+ target = target.reshape(target.shape[0], -1)
66
+ valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
67
+
68
+ num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
69
+ den = torch.sum(pred.pow(self.exponent)*valid_mask + target.pow(self.exponent)*valid_mask, dim=1) + max(self.smooth, 1e-5)
70
+
71
+ dice = num / den
72
+ dice = torch.mean(dice)
73
+ return 1 - dice
74
+
75
+ def forward(self, score, target):
76
+ ph, pw = score.size(2), score.size(3)
77
+ h, w = target.size(1), target.size(2)
78
+ if ph != h or pw != w:
79
+ score = F.interpolate(
80
+ input=score, size=(h, w), mode='bilinear', align_corners=False)
81
+
82
+ score = F.softmax(score,dim=1)
83
+ num_classes = score.shape[1]
84
+
85
+ one_hot_target = F.one_hot(
86
+ torch.clamp(target.long(), 0, num_classes - 1),
87
+ num_classes=num_classes)
88
+ valid_mask = (target != self.ignore_index).long()
89
+
90
+ loss = self.dice_loss(score, one_hot_target, valid_mask)
91
+ return loss
92
+
93
+
94
+ class BinaryDiceLoss(nn.Module):
95
+ def __init__(self, smooth=1, exponent=2, ignore_label=-1): #because padding adds -1 to the targets
96
+ super(BinaryDiceLoss, self).__init__()
97
+ self.ignore_index = ignore_label
98
+ self.smooth = smooth
99
+ self.exponent = exponent
100
+
101
+ def binary_dice_loss(self, pred, target, valid_mask):
102
+ assert pred.shape[0] == target.shape[0]
103
+ pred = pred.reshape(pred.shape[0], -1)
104
+ target = target.reshape(target.shape[0], -1)
105
+ valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
106
+
107
+ num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
108
+ den = torch.sum(pred.pow(self.exponent)*valid_mask + target.pow(self.exponent)*valid_mask, dim=1) + max(self.smooth, 1e-5)
109
+
110
+ dice = num / den
111
+ dice = torch.mean(dice)
112
+ return 1 - dice
113
+
114
+ def forward(self, score, target):
115
+ ph, pw = score.size(2), score.size(3)
116
+ h, w = target.size(1), target.size(2)
117
+ if ph != h or pw != w:
118
+ score = F.interpolate(
119
+ input=score, size=(h, w), mode='bilinear', align_corners=False)
120
+
121
+ score = F.softmax(score,dim=1)
122
+ num_classes = score.shape[1]
123
+
124
+ one_hot_target = F.one_hot(
125
+ torch.clamp(target.long(), 0, num_classes - 1),
126
+ num_classes=num_classes)
127
+ valid_mask = (target != self.ignore_index).long()
128
+
129
+ loss = self.binary_dice_loss(
130
+ score[:, 1],
131
+ one_hot_target[..., 1],
132
+ valid_mask)
133
+ return loss
134
+
135
+
136
+ class DiceEntropyLoss(nn.Module):
137
+ def __init__(self, smooth=1, exponent=2, ignore_label=-1, weight=None): #because padding adds -1 to the targets
138
+ super(DiceEntropyLoss, self).__init__()
139
+ self.ignore_label = ignore_label
140
+ self.smooth = smooth
141
+ self.exponent = exponent
142
+ self.cross_entropy = nn.CrossEntropyLoss(weight=weight,
143
+ ignore_index=ignore_label)
144
+
145
+ def binary_dice_loss(self, pred, target, valid_mask):
146
+ print(pred.shape, target.shape, "this is second list")
147
+ assert pred.shape[0] == target.shape[0]
148
+ pred = pred.reshape(pred.shape[0], -1)
149
+ target = target.reshape(target.shape[0], -1)
150
+ valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
151
+
152
+ # print(pred.shape, target.shape)
153
+ num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
154
+ den = torch.sum(pred.pow(self.exponent)*valid_mask + target.pow(self.exponent)*valid_mask, dim=1) + max(self.smooth, 1e-5)
155
+
156
+ dice = num / den
157
+ dice = torch.mean(dice)
158
+ return 1 - dice
159
+
160
+ def forward(self, score, target):
161
+ ph, pw = score.size(2), score.size(3)
162
+ h, w = target.size(1), target.size(2)
163
+ # if ph != h or pw != w:
164
+ # score = F.upsample(
165
+ # input=score, size=(h, w), mode='bilinear')
166
+ CE_loss = self.cross_entropy(score, target)
167
+
168
+
169
+ score = F.softmax(score,dim=1)
170
+ num_classes = score.shape[1]
171
+
172
+ one_hot_target = F.one_hot(
173
+ torch.clamp(target.long(), 0, num_classes - 1),
174
+ num_classes=num_classes)
175
+ valid_mask = (target != self.ignore_label).long()
176
+ # print(score.shape,one_hot_target.shape)
177
+ dice_loss = self.binary_dice_loss(
178
+ score[:, 1],
179
+ one_hot_target[..., 1],
180
+ valid_mask)
181
+
182
+ return 0.3*CE_loss + 0.7*dice_loss
183
+
184
+
185
+
186
+
187
+ class FocalLoss(nn.Module):
188
+
189
+ def __init__(self, alpha=0.25, gamma=2., ignore_label=-1): #alpha 0.25, gamma=2.
190
+ super(FocalLoss, self).__init__()
191
+ self.alpha=alpha
192
+ self.gamma= gamma
193
+ self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_label, reduction="none")
194
+
195
+ def forward(self, score, target):
196
+ ph, pw = score.size(2), score.size(3)
197
+ h, w = target.size(1), target.size(2)
198
+ if ph != h or pw != w:
199
+ score = F.interpolate(
200
+ input=score, size=(h, w), mode='bilinear', align_corners=False)
201
+
202
+ ce_loss = self.criterion(score, target)
203
+ pt = torch.exp(-ce_loss)
204
+ f_loss = self.alpha * (1-pt)**self.gamma * ce_loss
205
+ return f_loss.mean()
206
+
207
+
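
Written out, every `binary_dice_loss` above computes the same per-sample quantity, with $s$ = `smooth`, $e$ = `exponent`, $p$ the softmax probabilities, $t$ the one-hot target, and $v$ the valid mask:

$$
\mathcal{L}_{\text{dice}} = 1 - \frac{1}{B}\sum_{b=1}^{B} \frac{2\sum_i p_{b,i}\,t_{b,i}\,v_{b,i} + s}{\sum_i \left(p_{b,i}^{\,e} + t_{b,i}^{\,e}\right) v_{b,i} + \max(s,\, 10^{-5})}
$$

`FocalLoss` likewise rescales the per-pixel cross-entropy $\mathrm{CE}$ by the model's confidence in the correct class:

$$
\mathcal{L}_{\text{focal}} = \operatorname{mean}\!\left[\alpha\,(1 - p_t)^{\gamma}\,\mathrm{CE}\right], \qquad p_t = e^{-\mathrm{CE}}
$$

`DiceEntropyLoss` then mixes the two terms as `0.3*CE_loss + 0.7*dice_loss`.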
utils/losses.py ADDED
@@ -0,0 +1,440 @@
1
+ # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
2
+ # Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA').
3
+ #
4
+ # All rights reserved.
5
+ # This work should only be used for nonprofit purposes.
6
+ #
7
+ # By downloading and/or using any of these files, you implicitly agree to all the
8
+ # terms of the license, as specified in the document LICENSE.txt
9
+ # (included in this package) and online at
10
+ # http://www.grip.unina.it/download/LICENSE_OPEN.txt
11
+
12
+ """
13
+ Created in September 2022
14
+ @author: fabrizio.guillaro
15
+ """
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import functional as F
20
+
21
+
22
+
23
+ class CrossEntropy(nn.Module):
24
+ def __init__(self, ignore_label=-1, weight=None):
25
+ super(CrossEntropy, self).__init__()
26
+ self.ignore_label = ignore_label
27
+ self.criterion = nn.CrossEntropyLoss(weight=weight,
28
+ ignore_index=ignore_label)
29
+
30
+ def forward(self, score, target):
31
+ ph, pw = score.size(2), score.size(3)
32
+ h, w = target.size(1), target.size(2)
33
+ if ph != h or pw != w:
34
+ score = F.interpolate(
35
+ input=score, size=(h, w), mode='bilinear', align_corners=False)
36
+
37
+ loss = self.criterion(score, target)
38
+ return loss
39
+
40
+
41
+
42
+ class DiceLoss(nn.Module):
43
+ def __init__(self, ignore_label=-1, smooth=1, exponent=2): #because padding adds -1 to the targets
44
+ super(DiceLoss, self).__init__()
45
+ self.ignore_index = ignore_label
46
+ self.smooth = smooth
47
+ self.exponent = exponent
48
+
49
+ def dice_loss(self, pred, target, valid_mask):
50
+ assert pred.shape[0] == target.shape[0]
51
+ total_loss = 0
52
+ num_classes = pred.shape[1]
53
+ for i in range(num_classes):
54
+ if i != self.ignore_index:
55
+ dice_loss = self.binary_dice_loss(
56
+ pred[:, i],
57
+ target[..., i],
58
+ valid_mask=valid_mask,)
59
+ total_loss += dice_loss
60
+ return total_loss / num_classes
61
+
62
+ def binary_dice_loss(self, pred, target, valid_mask):
63
+ assert pred.shape[0] == target.shape[0]
64
+ pred = pred.reshape(pred.shape[0], -1)
65
+ target = target.reshape(target.shape[0], -1)
66
+ valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
67
+
68
+ num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
69
+ den = torch.sum(pred.pow(self.exponent)*valid_mask + target.pow(self.exponent)*valid_mask, dim=1) + max(self.smooth, 1e-5)
70
+
71
+ dice = num / den
72
+ dice = torch.mean(dice)
73
+ return 1 - dice
74
+
75
+ def forward(self, score, target):
76
+ ph, pw = score.size(2), score.size(3)
77
+ h, w = target.size(1), target.size(2)
78
+ if ph != h or pw != w:
79
+ score = F.upsample(
80
+ input=score, size=(h, w), mode='bilinear')
81
+
82
+ score = F.softmax(score,dim=1)
83
+ num_classes = score.shape[1]
84
+
85
+ one_hot_target = F.one_hot(
86
+ torch.clamp(target.long(), 0, num_classes - 1),
87
+ num_classes=num_classes)
88
+ valid_mask = (target != self.ignore_index).long()
89
+
90
+ loss = self.dice_loss(score, one_hot_target, valid_mask)
91
+ return loss
92
+
93
+
94
+
95
+ class BinaryDiceLoss(nn.Module):
96
+ def __init__(self, smooth=1, exponent=2, ignore_label=-1): #because padding adds -1 to the targets
97
+ super(BinaryDiceLoss, self).__init__()
98
+ self.ignore_index = ignore_label
99
+ self.smooth = smooth
100
+ self.exponent = exponent
101
+
102
+ def binary_dice_loss(self, pred, target, valid_mask):
103
+ assert pred.shape[0] == target.shape[0]
104
+ print(pred.shape, target.shape)
105
+ pred = pred.reshape(pred.shape[0], -1)
106
+ target = target.reshape(target.shape[0], -1)
107
+ valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
108
+ print(pred.shape, target.shape)
109
+ num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
110
+ den = torch.sum(pred.pow(self.exponent)*valid_mask + target.pow(self.exponent)*valid_mask, dim=1) + max(self.smooth, 1e-5)
111
+
112
+ dice = num / den
113
+ dice = torch.mean(dice)
114
+ return 1 - dice
115
+
116
+ def forward(self, score, target):
117
+
118
+ ph, pw = score.size(2), score.size(3)
119
+ h, w = target.size(2), target.size(3)
120
+ if ph != h or pw != w:
121
+ score = F.upsample(
122
+ input=score, size=(h, w), mode='bilinear')
123
+
124
+ score = F.softmax(score,dim=1)
125
+ num_classes = score.shape[1]
126
+
127
+ one_hot_target = F.one_hot(
128
+ torch.clamp(target.long(), 0, num_classes - 1),
129
+ num_classes=num_classes)
130
+ valid_mask = (target != self.ignore_index).long()
131
+
132
+ loss = self.binary_dice_loss(
133
+ score[:, 1],
134
+ one_hot_target[..., 1],
135
+ valid_mask)
136
+ return loss
137
+
138
+ def create_target_from_mask_and_label(mask, data_label):
139
+ """
140
+ Convert binary mask to class-labeled target based on data_label.
141
+
142
+ Args:
143
+ mask: B H W with values 0 (black/background) or 1 (white/foreground)
144
+ data_label: B×1 tensor or B tensor with values [0, 1, 2, 3]
145
+ - 0: background (no edit)
146
+ - 1: physical edit (Photoshop)
147
+ - 2: synthetic AI edit
148
+ - 3: other edit type
149
+
150
+ Returns:
151
+ target: B H W with values [0, 1, 2, 3]
152
+ - 0: unedited pixels (mask == 0)
153
+ - 1, 2, 3: edited pixels with their respective class labels
154
+ """
155
+
156
+ # Handle if mask has channel dimension
157
+ if mask.dim() == 4: # B×1×H×W
158
+ mask = mask.squeeze(1) # B×H×W
159
+
160
+ # Handle if data_label has extra dimensions
161
+ if data_label.dim() > 1:
162
+ data_label = data_label.squeeze() # B
163
+
164
+ B, H, W = mask.shape
165
+
166
+ # Initialize target with zeros (background class)
167
+ target = torch.zeros(B, H, W, dtype=torch.long, device=mask.device)
168
+
169
+ # For each sample in batch
170
+ for b in range(B):
171
+ # Get the class label for this sample
172
+ class_label = data_label[b].item() if data_label.dim() > 0 else data_label.item()
173
+
174
+ # Where mask is white (1), set the target to the class label
175
+ # Where mask is black (0), keep target as 0 (background)
176
+ target[b][mask[b] == 1] = class_label
177
+
178
+ return target
179
+
180
+
181
+ def debug_target_creation(target, data_label, batch_size=4):
182
+ """
183
+ Debug function to print data_label and target mapping before and after conversion.
184
+
185
+ Args:
186
+ target: Binary mask B×H×W or B×1×H×W with values 0 or 1
187
+ data_label: B tensor with class labels [0, 1, 2, 3]
188
+ """
189
+
190
+ print("="*80)
191
+ print("DEBUGGING TARGET CREATION")
192
+ print("="*80)
193
+
194
+ # Print original inputs
195
+ print("\n--- BEFORE CONVERSION ---")
196
+ print(f"Data Label shape: {data_label.shape}")
197
+ print(f"Data Label values: {data_label}")
198
+ print(f"Data Label dtype: {data_label.dtype}")
199
+
200
+ print(f"\nTarget (mask) shape: {target.shape}")
201
+ print(f"Target (mask) unique values: {torch.unique(target)}")
202
+ print(f"Target (mask) dtype: {target.dtype}")
203
+
204
+ # Print per-sample details BEFORE
205
+ print("\n--- PER-SAMPLE BREAKDOWN (BEFORE) ---")
206
+ if target.dim() == 4: # B×1×H×W
207
+ target_2d = target.squeeze(1)
208
+ else:
209
+ target_2d = target
210
+
211
+ B = target_2d.shape[0]
212
+ for b in range(min(B, batch_size)):
213
+ edited_pixels = (target_2d[b] == 1).sum().item()
214
+ total_pixels = target_2d[b].numel()
215
+ label = data_label[b].item() if data_label.dim() > 0 else data_label.item()
216
+ print(f" Sample {b}: Label={label}, Edited pixels={edited_pixels}/{total_pixels}")
217
+
218
+ # Create target
219
+ target_converted = create_target_from_mask_and_label(target, data_label)
220
+
221
+ # Print AFTER conversion
222
+ print("\n--- AFTER CONVERSION ---")
223
+ print(f"Target (converted) shape: {target_converted.shape}")
224
+ print(f"Target (converted) unique values: {torch.unique(target_converted)}")
225
+ print(f"Target (converted) dtype: {target_converted.dtype}")
226
+
227
+ # Print per-sample details AFTER
228
+ print("\n--- PER-SAMPLE BREAKDOWN (AFTER) ---")
229
+ for b in range(min(B, batch_size)):
230
+ label = data_label[b].item() if data_label.dim() > 0 else data_label.item()
231
+
232
+ # Count pixels for each class
233
+ class_counts = {}
234
+ for class_id in range(4):
235
+ count = (target_converted[b] == class_id).sum().item()
236
+ class_counts[class_id] = count
237
+
238
+ print(f" Sample {b}:")
239
+ print(f" Label (expected): {label}")
240
+ print(f" Class distribution: {class_counts}")
241
+
242
+ # Verify correctness
243
+ if label == 0:
244
+ # All pixels should be background (0)
245
+ if class_counts[0] == target_converted[b].numel():
246
+ print(f" ✓ CORRECT: All pixels are class 0 (background)")
247
+ else:
248
+ print(f" ✗ ERROR: Expected all pixels to be 0, but got {class_counts}")
249
+ else:
250
+ # Non-background pixels should have the label
251
+ if class_counts[label] > 0:
252
+ print(f" ✓ CORRECT: Found {class_counts[label]} pixels with class {label}")
253
+ else:
254
+ print(f" ✗ ERROR: Expected class {label} pixels but found none")
255
+
256
+ print("\n" + "="*80)
257
+
258
+ return target_converted
259
+
260
+ class MultiClassDiceEntropyLoss(nn.Module):
261
+ """
262
+ Multi-class segmentation loss combining Dice and CrossEntropy.
263
+ Supports classes: 0 (background), 1, 2, 3
264
+ """
265
+ def __init__(self, num_classes=4, smooth=1e-5, dice_weight=0.5, ce_weight=0.5,
266
+ ignore_index=-1):
267
+ super(MultiClassDiceEntropyLoss, self).__init__()
268
+ self.num_classes = num_classes
269
+ self.smooth = smooth
270
+ self.dice_weight = dice_weight
271
+ self.ce_weight = ce_weight
272
+ self.ignore_index = ignore_index
273
+
274
+ # CrossEntropy loss
275
+ self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)
276
+
277
+ def dice_loss(self, pred, target, valid_mask=None):
278
+ """
279
+ Compute Dice loss per class and average
280
+
281
+ pred: B C H W (softmax probabilities)
282
+ target: B H W (class indices 0-3)
283
+ valid_mask: B H W (1 for valid, 0 for ignore)
284
+ """
285
+ dice_losses = []
286
+
287
+ for class_id in range(self.num_classes):
288
+ # One-hot encode for this class
289
+ pred_class = pred[:, class_id, :, :] # B×H×W
290
+ target_class = (target == class_id).float() # B×H×W
291
+
292
+ # Flatten
293
+ pred_flat = pred_class.reshape(-1)
294
+ target_flat = target_class.reshape(-1)
295
+
296
+ # Apply valid mask if provided
297
+ if valid_mask is not None:
298
+ valid_flat = valid_mask.reshape(-1)
299
+ pred_flat = pred_flat * valid_flat
300
+ target_flat = target_flat * valid_flat
301
+
302
+ # Dice computation
303
+ intersection = torch.sum(pred_flat * target_flat)
304
+ union = torch.sum(pred_flat) + torch.sum(target_flat)
305
+
306
+ dice = (2 * intersection + self.smooth) / (union + self.smooth)
307
+ dice_losses.append(1 - dice)
308
+
309
+ return torch.mean(torch.stack(dice_losses))
310
+
311
+ def forward(self, score, target, data_label):
312
+ """
313
+ pred: B 1 H W (U-Net output, raw logits)
314
+ target: B H W (class labels: 0, 1, 2, or 3)
315
+ """
316
+ # Handle if target has channel dimension
317
+ if target.dim() == 4: # B×1×H×W
318
+ target = target.squeeze(1) # B×H×W
319
+ # target = create_target_from_mask_and_label(target, data_label)
320
+
321
+
322
+ # test_result = debug_target_creation(target, data_label, batch_size=1)
323
+
324
+ # Ensure target is long dtype
325
+ target = target.long()
326
+
327
+ # Upsample pred if needed
328
+ if score.shape[2:] != target.shape[1:]:
329
+ score = F.interpolate(score, size=target.shape[1:], mode='bilinear', align_corners=False)
330
+
331
+ # Convert single channel to multi-class
332
+ # If score is B×1×H×W, we need to expand it to B×C×H×W
333
+ # if score.shape[1] == 1:
334
+ # # U-Net outputs 1 channel, we need to create num_classes channels
335
+ # # This assumes your U-Net needs modification OR
336
+ # # we convert single channel to multi-class logits
337
+ # raise ValueError(
338
+ # f"U-Net outputs {score.shape[1]} channel but {self.num_classes} classes expected. "
339
+ # "Modify U-Net output layer to have num_classes={} channels".format(self.num_classes)
340
+ # )
341
+
342
+ # Apply softmax to get probabilities
343
+ score_probs = F.softmax(score, dim=1) # B×C×H×W
344
+
345
+ # CrossEntropy loss
346
+ ce_loss = self.ce_loss(score, target)
347
+
348
+ # Valid mask (exclude ignore_index)
349
+ valid_mask = (target != self.ignore_index).float()
350
+
351
+ # Dice loss
352
+ dice_loss = self.dice_loss(score_probs, target, valid_mask)
353
+
354
+ # Combined loss
355
+ total_loss = self.dice_weight * dice_loss + self.ce_weight * ce_loss
356
+
357
+ return total_loss
358
+
359
+
360
+ class DiceEntropyLoss(nn.Module):
361
+ def __init__(self, smooth=1, exponent=2, ignore_label=-1, weight=None): #because padding adds -1 to the targets
362
+ super(DiceEntropyLoss, self).__init__()
363
+ self.ignore_label = ignore_label
364
+ self.smooth = smooth
365
+ self.exponent = exponent
366
+ self.cross_entropy = nn.CrossEntropyLoss(weight=weight,
367
+ ignore_index=ignore_label)
368
+
369
+ def binary_dice_loss(self, pred, target, valid_mask):
370
+ assert pred.shape[0] == target.shape[0]
371
+ pred = pred.reshape(pred.shape[0], -1)
372
+ target = target.reshape(target.shape[0], -1)
373
+ valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
374
+
375
+ num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
376
+ den = torch.sum(pred.pow(self.exponent)*valid_mask + target.pow(self.exponent)*valid_mask, dim=1) + max(self.smooth, 1e-5)
377
+
378
+ dice = num / den
379
+ dice = torch.mean(dice)
380
+ return 1 - dice
381
+
382
+ def forward(self, score, target):
383
+ target = target.squeeze(1).long()
384
+
385
+ target = torch.clamp(target, min=0, max=1)
386
+ ph, pw = score.size(2), score.size(3) # (B,1,224,224)
387
+ h, w = target.size(1), target.size(2) # (B,224,224)
388
+ if ph != h or pw != w:
389
+ score = F.upsample(
390
+ input=score, size=(h, w), mode='bilinear')
391
+
392
+ CE_loss = self.cross_entropy(score, target)
393
+
394
+
395
+ score = F.softmax(score,dim=1)
396
+ num_classes = score.shape[1]
397
+
398
+ one_hot_target = F.one_hot(
399
+ torch.clamp(target.long(), 0, num_classes - 1),
400
+ num_classes=num_classes)
401
+ valid_mask = (target != self.ignore_label).long()
402
+
403
+ # dice_loss = self.binary_dice_loss(
404
+ # score[:, 1],
405
+ # one_hot_target[..., 1],
406
+ # valid_mask)
407
+
408
+ number_of_present_classes = 4
409
+ dice_loss = 0
410
+ for class_id in [1,2,3]:
411
+ if (target == class_id).sum() > 0:
412
+ dice_loss += dice(pred[:, class_id], target_onehot[:, class_id])
413
+ dice_loss /= number_of_present_classes
414
+
415
+ return 0.3*CE_loss + 0.7*dice_loss
416
+
417
+
418
+
419
+
420
+ class FocalLoss(nn.Module):
421
+
422
+ def __init__(self, alpha=0.25, gamma=2., ignore_label=-1): #alpha 0.25, gamma=2.
423
+ super(FocalLoss, self).__init__()
424
+ self.alpha=alpha
425
+ self.gamma= gamma
426
+ self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_label, reduction="none")
427
+
428
+ def forward(self, score, target):
429
+ ph, pw = score.size(2), score.size(3)
430
+ h, w = target.size(1), target.size(2)
431
+ if ph != h or pw != w:
432
+ score = F.upsample(
433
+ input=score, size=(h, w), mode='bilinear')
434
+
435
+ ce_loss = self.criterion(score, target)
436
+ pt = torch.exp(-ce_loss)
437
+ f_loss = self.alpha * (1-pt)**self.gamma * ce_loss
438
+ return f_loss.mean()
439
+
440
+
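
Since `MultiClassDiceEntropyLoss` is the newest variant here, a minimal smoke-test sketch may help. Shapes follow its docstrings, all tensors are synthetic, and note that `data_label` is currently unused because the remapping call in `forward` is commented out:

```python
# Smoke test for MultiClassDiceEntropyLoss (class defined above).
import torch

criterion = MultiClassDiceEntropyLoss(num_classes=4, dice_weight=0.5, ce_weight=0.5)

logits = torch.randn(2, 4, 64, 64)          # B x C x H x W raw scores
target = torch.randint(0, 4, (2, 64, 64))   # B x H x W class indices in {0, 1, 2, 3}
data_label = torch.randint(0, 4, (2,))      # per-sample edit-type label (ignored while remapping is disabled)

loss = criterion(logits, target, data_label)
print(float(loss))  # scalar: 0.5 * dice loss + 0.5 * cross-entropy
```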