dthh commited on
Commit
f000f4b
·
verified ·
1 Parent(s): dd9d3f9

Upload prediction_example.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. prediction_example.py +54 -0
prediction_example.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from PIL import Image
4
+ import logging
5
+
6
+ import Models
7
+
8
+ config = {
9
+ "model_root": "models",
10
+ "hf_model_repo": "SurfaceAI/models",
11
+ "models": {
12
+ "surface_type": "v1/surface_type_v1.pt",
13
+ "surface_quality": {
14
+ "asphalt": "v1/surface_quality_asphalt_v1.pt",
15
+ "concrete": "v1/surface_quality_concrete_v1.pt",
16
+ "paving_stones": "v1/surface_quality_paving_stones_v1.pt",
17
+ "sett": "v1/surface_quality_sett_v1.pt",
18
+ "unpaved": "v1/surface_quality_unpaved_v1.pt"
19
+ },
20
+ "road_type": "v1/road_type_v1.pt"
21
+ },
22
+ "gpu_kernel": 0,
23
+ "transform_surface": {
24
+ "resize": 384,
25
+ "crop": "lower_middle_half"
26
+ },
27
+ "transform_road_type": {
28
+ "resize": 384,
29
+ "crop": "lower_half"
30
+ },
31
+ }
32
+
33
+ root_path = Path(os.path.abspath(__file__)).parent
34
+
35
+ image_ids = [
36
+ # "IMG_20210221_135447",
37
+ "IMG_20210226_172956",
38
+ # "IMG_20230130_162826",
39
+ ]
40
+
41
+ logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
42
+
43
+ image_data = []
44
+ for id in image_ids:
45
+ path = root_path / "example_images" / f"{id}.jpg"
46
+ try:
47
+ image_data.append(Image.open(path))
48
+ except Exception as e:
49
+ logging.warning(f'{e}: Not found or corrupted image: {path}')
50
+
51
+ md = Models.ModelInterface(config=config)
52
+ results = md.batch_classifications(image_data, image_ids)
53
+ for result in results:
54
+ print(result)