dthh commited on
Commit
e7840f6
·
verified ·
1 Parent(s): 1e20859

Upload prediction_example.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. prediction_example.py +52 -0
prediction_example.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from PIL import Image
4
+ import pandas as pd
5
+
6
+ import Models
7
+
8
+ config = {
9
+ "model_root": "models",
10
+ "hf_model_repo": "SurfaceAI/models",
11
+ "models": {
12
+ "surface_type": "v1/surface_type_v1.pt",
13
+ "surface_quality": {
14
+ "asphalt": "v1/surface_quality_asphalt_v1.pt",
15
+ "concrete": "v1/surface_quality_concrete_v1.pt",
16
+ "paving_stones": "v1/surface_quality_paving_stones_v1.pt",
17
+ "sett": "v1/surface_quality_sett_v1.pt",
18
+ "unpaved": "v1/surface_quality_unpaved_v1.pt"
19
+ },
20
+ "road_type": "v1/road_type_v1.pt"
21
+ },
22
+ "gpu_kernel": 0,
23
+ "transform_surface": {
24
+ "resize": 384,
25
+ "crop": "lower_middle_half"
26
+ },
27
+ "transform_road_type": {
28
+ "resize": 384,
29
+ "crop": "lower_half"
30
+ },
31
+ }
32
+
33
+ root_path = Path(os.path.abspath(__file__)).parent
34
+
35
+ image_ids = [
36
+ "1351262795801113",
37
+ "153111940043147",
38
+ "1424818291203908",
39
+ ]
40
+
41
+ image_data = []
42
+ for id in image_ids:
43
+ path = root_path / "example_images" / f"{id}.jpg"
44
+ try:
45
+ image_data.append(Image.open(path))
46
+ except Exception as e:
47
+ print(f'{e}: Not found or corrupted image: {path}')
48
+ continue
49
+
50
+ md = Models.ModelInterface(config=config)
51
+ df = md.batch_classifications(image_data, image_ids)
52
+ df.to_csv("example_prediction.csv")