Ha Trong Nguyen commited on
Commit
efe5fd6
·
1 Parent(s): e800115

feat: add batch inference testing script and update documentation

Browse files
Files changed (2) hide show
  1. README.md +30 -2
  2. test_batch_inference.py +138 -0
README.md CHANGED
@@ -10,5 +10,33 @@ license: mit
10
 
11
  # TrafficFlow API Backend
12
 
13
- This is the FastAPI backend and ZIP inference model for the TrafficFlow application.
14
- It exposes endpoints for getting HCMC camera data, predicting vehicle counts, and proxying camera images.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  # TrafficFlow API Backend
12
 
13
+ Đây dịch vụ Backend FastAPI hình AI Zero-shot Image Prior (ZIP) cho dự án TrafficFlow.
14
+ Hệ thống cung cấp các API để truy xuất camera giao thông TP.HCM, dự đoán lưu lượng xe cộ và proxy hình ảnh camera.
15
+
16
+ ---
17
+
18
+ ## 🚀 Hướng dẫn Đánh giá Mô hình (Model Evaluation)
19
+
20
+ Dự án cung cấp sẵn kịch bản kiểm thử tự động `test_batch_inference.py` để sinh viên/nhóm nghiên cứu có thể chạy hàng loạt ảnh và lấy kết quả thống kê.
21
+
22
+ ### 1. Cách chạy kịch bản kiểm thử (Script)
23
+
24
+ 1. Cài đặt thư viện: `pip install httpx`
25
+ 2. Tạo thư mục `test_images/` nằm cùng cấp với file script và chép các ảnh camera giao thông cần kiểm thử vào (hỗ trợ `.jpg`, `.png`).
26
+ 3. Chạy lệnh: `python test_batch_inference.py`
27
+ 4. Code sẽ tự động gửi ảnh lên Hugging Face Endpoint và xuất ra file báo cáo `evaluation_results.csv`.
28
+
29
+ _(Lưu ý: Mặc định script sẽ tạm dừng 1 giây giữa các ảnh để tránh làm quá tải (spam) Endpoint)._
30
+
31
+ ### 2. Giải thích các Thông số Đầu ra (Output Parameters)
32
+
33
+ Kết quả trong file `.csv` chứa các tham số quan trọng sau:
34
+
35
+ - **`total_count`**: Tổng số phương tiện đếm được trong ảnh.
36
+ - **`car_count` / `motorbike_count`**: Số lượng dự đoán bóc tách riêng từng loại xe (Dựa trên tỷ lệ ước tính của hệ thống).
37
+ - **`density_level`**: Phân loại mức độ kẹt xe AI đánh giá:
38
+ - `low` (Thông thoáng): Mật độ xe thưa thớt, đường trống.
39
+ - `moderate` (Đông vừa): Xe bắt đầu đông nhưng vẫn di chuyển ổn định.
40
+ - `heavy` (Kẹt xe): Lượng xe rất đông, có dấu hiệu ùn ứ.
41
+ - `severe` (Kẹt cứng): Lòng đường đặc kín xe, không thể di chuyển.
42
+ - **`latency_seconds`**: Thời gian phản hồi (tính bằng giây). Bao gồm thời gian mạng truyền tải (Network latency) + Thời gian mô hình AI phân tích (Inference time).
test_batch_inference.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import glob
4
+ import time
5
+ import csv
6
+ from pathlib import Path
7
+ import httpx
8
+ import asyncio
9
+
10
+ # Cấu hình
11
+ API_URL = "https://htrnguyen-trafficflow-api.hf.space/api/predict"
12
+ INPUT_FOLDER = "./test_images"
13
+ OUTPUT_CSV = "evaluation_results.csv"
14
+ SLEEP_SECONDS = 1.5
15
+
16
+
17
+ async def predict_image(client: httpx.AsyncClient, image_path: str):
18
+ filename = os.path.basename(image_path)
19
+ try:
20
+ with open(image_path, "rb") as f:
21
+ files = {"file": (filename, f, "image/jpeg")}
22
+
23
+ start_time = time.time()
24
+ response = await client.post(API_URL, files=files, timeout=60.0)
25
+ latency = time.time() - start_time
26
+
27
+ if response.status_code == 200:
28
+ data = response.json()
29
+ pred = data.get("prediction", {})
30
+ return {
31
+ "filename": filename,
32
+ "status": "success",
33
+ "total_count": pred.get("total_count", 0),
34
+ "car_count": pred.get("car_count", 0),
35
+ "motorbike_count": pred.get("motorbike_count", 0),
36
+ "person_count": pred.get("person_count", 0),
37
+ "density_level": pred.get("density_level", "unknown"),
38
+ "latency_seconds": round(latency, 2),
39
+ "error": "",
40
+ }
41
+ else:
42
+ return {
43
+ "filename": filename,
44
+ "status": "failed",
45
+ "total_count": 0,
46
+ "car_count": 0,
47
+ "motorbike_count": 0,
48
+ "person_count": 0,
49
+ "density_level": "error",
50
+ "latency_seconds": round(latency, 2),
51
+ "error": f"HTTP {response.status_code}: {response.text}",
52
+ }
53
+ except Exception as e:
54
+ return {
55
+ "filename": filename,
56
+ "status": "error",
57
+ "total_count": 0,
58
+ "car_count": 0,
59
+ "motorbike_count": 0,
60
+ "person_count": 0,
61
+ "density_level": "error",
62
+ "latency_seconds": 0,
63
+ "error": str(e),
64
+ }
65
+
66
+
67
+ async def main():
68
+ if not os.path.exists(INPUT_FOLDER):
69
+ print(f"[ERROR] Không tìm thấy thư mục: {INPUT_FOLDER}")
70
+ print("[INFO] Vui lòng tạo thư mục này và chép ảnh (jpg, png) vào để bắt đầu.")
71
+ os.makedirs(INPUT_FOLDER, exist_ok=True)
72
+ return
73
+
74
+ image_extensions = ["*.jpg", "*.jpeg", "*.png"]
75
+ image_paths = []
76
+ for ext in image_extensions:
77
+ image_paths.extend(glob.glob(os.path.join(INPUT_FOLDER, ext)))
78
+ image_paths.extend(glob.glob(os.path.join(INPUT_FOLDER, ext.upper())))
79
+
80
+ if not image_paths:
81
+ print(f"[WARNING] Không có ảnh nào trong thư mục {INPUT_FOLDER}.")
82
+ return
83
+
84
+ print(f"[INFO] Bắt đầu đánh giá {len(image_paths)} ảnh qua API: {API_URL}")
85
+ results = []
86
+
87
+ async with httpx.AsyncClient() as client:
88
+ for i, path in enumerate(image_paths, 1):
89
+ print(
90
+ f"[{i}/{len(image_paths)}] Đang xử lý: {os.path.basename(path)} ... ",
91
+ end="",
92
+ flush=True,
93
+ )
94
+ res = await predict_image(client, path)
95
+ results.append(res)
96
+
97
+ if res["status"] == "success":
98
+ print(
99
+ f"[SUCCESS] {res['total_count']} phương tiện ({res['latency_seconds']}s)"
100
+ )
101
+ else:
102
+ print(f"[ERROR] {res['error']}")
103
+
104
+ if i < len(image_paths):
105
+ print(f"[INFO] Tạm dừng {SLEEP_SECONDS}s trước khi gửi tiếp...")
106
+ await asyncio.sleep(SLEEP_SECONDS)
107
+
108
+ fieldnames = [
109
+ "filename",
110
+ "status",
111
+ "total_count",
112
+ "car_count",
113
+ "motorbike_count",
114
+ "person_count",
115
+ "density_level",
116
+ "latency_seconds",
117
+ "error",
118
+ ]
119
+
120
+ with open(OUTPUT_CSV, mode="w", newline="", encoding="utf-8-sig") as f:
121
+ writer = csv.DictWriter(f, fieldnames=fieldnames)
122
+ writer.writeheader()
123
+ writer.writerows(results)
124
+
125
+ print(f"\n[INFO] Hoàn tất quá trình đánh giá. Báo cáo lưu tại: {OUTPUT_CSV}")
126
+
127
+ successful_runs = [r for r in results if r["status"] == "success"]
128
+ if successful_runs:
129
+ avg_latency = sum(r["latency_seconds"] for r in successful_runs) / len(
130
+ successful_runs
131
+ )
132
+ print(f"[REPORT] Thống kê hiệu năng:")
133
+ print(f" - Số lượng thành công: {len(successful_runs)}/{len(results)}")
134
+ print(f" - Thời gian phản hồi trung bình: {avg_latency:.2f} giây/ảnh")
135
+
136
+
137
+ if __name__ == "__main__":
138
+ asyncio.run(main())