v1 fix
Browse files- inference.py +2 -2
- model.pt +1 -1
- train.py +6 -6
inference.py
CHANGED
|
@@ -54,8 +54,8 @@ class CompressionArtifactPredictor:
|
|
| 54 |
self.quality_ranges = {
|
| 55 |
'jpeg': (0, 100),
|
| 56 |
'webp': (0, 100),
|
| 57 |
-
'avif': (0,
|
| 58 |
-
'jxl': (0,
|
| 59 |
}
|
| 60 |
|
| 61 |
def predict(self, image: Image.Image) -> Dict[str, Dict[str, float]]:
|
|
|
|
| 54 |
self.quality_ranges = {
|
| 55 |
'jpeg': (0, 100),
|
| 56 |
'webp': (0, 100),
|
| 57 |
+
'avif': (0, 100),
|
| 58 |
+
'jxl': (0, 100)
|
| 59 |
}
|
| 60 |
|
| 61 |
def predict(self, image: Image.Image) -> Dict[str, Dict[str, float]]:
|
model.pt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 9386549
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1f260ddcbab4ecab0db436a2c72e146932bb10011b4ba841755fc63f49151979
|
| 3 |
size 9386549
|
train.py
CHANGED
|
@@ -27,8 +27,8 @@ class Config:
|
|
| 27 |
QUALITY_RANGES = {
|
| 28 |
'jpeg': (0, 100),
|
| 29 |
'webp': (0, 100),
|
| 30 |
-
'avif': (0,
|
| 31 |
-
'jxl': (0,
|
| 32 |
}
|
| 33 |
|
| 34 |
# Training
|
|
@@ -57,9 +57,9 @@ def ensure_dir(path: str):
|
|
| 57 |
def quality_to_normalized(quality: float, type: str) -> float:
|
| 58 |
"""Normalize JPEG quality [0,100] to [0,1]"""
|
| 59 |
if type == "avif":
|
| 60 |
-
return quality /
|
| 61 |
if type == "jxl":
|
| 62 |
-
return quality /
|
| 63 |
return quality / 100.0
|
| 64 |
|
| 65 |
|
|
@@ -125,14 +125,14 @@ class CompressionDataset(Dataset):
|
|
| 125 |
targets.append(quality_to_normalized(quality, "webp"))
|
| 126 |
formats.append(Config.COMPRESSION_FORMATS.index("webp"))
|
| 127 |
|
| 128 |
-
quality = random.randint(0,
|
| 129 |
compressed = compress_image(image.copy(), "avif", quality)
|
| 130 |
tensor = transforms.ToTensor()(compressed)
|
| 131 |
images.append(tensor)
|
| 132 |
targets.append(quality_to_normalized(quality, "avif"))
|
| 133 |
formats.append(Config.COMPRESSION_FORMATS.index("avif"))
|
| 134 |
|
| 135 |
-
quality = random.randint(0,
|
| 136 |
compressed = compress_image(image.copy(), "jxl", quality)
|
| 137 |
tensor = transforms.ToTensor()(compressed)
|
| 138 |
images.append(tensor)
|
|
|
|
| 27 |
QUALITY_RANGES = {
|
| 28 |
'jpeg': (0, 100),
|
| 29 |
'webp': (0, 100),
|
| 30 |
+
'avif': (0, 100),
|
| 31 |
+
'jxl': (0, 100)
|
| 32 |
}
|
| 33 |
|
| 34 |
# Training
|
|
|
|
| 57 |
def quality_to_normalized(quality: float, type: str) -> float:
|
| 58 |
"""Normalize JPEG quality [0,100] to [0,1]"""
|
| 59 |
if type == "avif":
|
| 60 |
+
return quality / 100
|
| 61 |
if type == "jxl":
|
| 62 |
+
return quality / 100
|
| 63 |
return quality / 100.0
|
| 64 |
|
| 65 |
|
|
|
|
| 125 |
targets.append(quality_to_normalized(quality, "webp"))
|
| 126 |
formats.append(Config.COMPRESSION_FORMATS.index("webp"))
|
| 127 |
|
| 128 |
+
quality = random.randint(0, 100)
|
| 129 |
compressed = compress_image(image.copy(), "avif", quality)
|
| 130 |
tensor = transforms.ToTensor()(compressed)
|
| 131 |
images.append(tensor)
|
| 132 |
targets.append(quality_to_normalized(quality, "avif"))
|
| 133 |
formats.append(Config.COMPRESSION_FORMATS.index("avif"))
|
| 134 |
|
| 135 |
+
quality = random.randint(0, 100)
|
| 136 |
compressed = compress_image(image.copy(), "jxl", quality)
|
| 137 |
tensor = transforms.ToTensor()(compressed)
|
| 138 |
images.append(tensor)
|