hquan21 commited on
Commit
cf3fc54
·
1 Parent(s): 894f3e1

feat: update new model

Browse files
Files changed (2) hide show
  1. app.py +62 -35
  2. requirements.txt +1 -0
app.py CHANGED
@@ -6,47 +6,74 @@ import json
6
 
7
  print("Đang tải các mô hình AI...")
8
 
9
- # 1. Load mô hình block (Phân loại ảnh vạn vật của Google)
10
- checker = pipeline("image-classification", model="google/vit-base-patch16-224")
11
-
12
- # 2. Load model Định giá
13
- model_id = "hquan21/ai-bike-pricing-up"
14
  processor = BlipProcessor.from_pretrained(model_id)
15
- model = BlipForConditionalGeneration.from_pretrained(model_id)
 
 
 
16
 
17
- print("Tải hình hoàn tất! Sẵn sàng phục vụ.")
18
 
19
  def predict(img):
20
- ket_qua_nhan_dien = checker(img)
21
-
22
- # Lấy 5 nhãn có xác suất cao nhất
23
- top_labels = [res['label'].lower() for res in ket_qua_nhan_dien[:5]]
24
-
25
- # Bộ từ khóa xe đạp trong từ điển ImageNet (viT model)
26
- is_bicycle = any("bicycle" in label or "bike" in label or "velocipede" in label for label in top_labels)
27
-
28
- # Nếu đưa ảnh chó mèo, ô tô, người...
29
- if not is_bicycle:
30
- error_msg = {
31
- "error": "Hình ảnh không chứa xe đạp, hoặc góc chụp không rõ ràng. Vui lòng chụp lại!"
32
- }
33
- # Trả về JSON để FE dễ dàng hiển thị Popup lỗi
34
- return json.dumps(error_msg, ensure_ascii=False)
35
-
36
- inputs = processor(img, return_tensors="pt")
37
-
38
- # Ép AI sinh ra tối đa 150 token (đủ dài cho JSON)
39
- out = model.generate(**inputs, max_new_tokens=150)
40
- result_text = processor.decode(out[0], skip_special_tokens=True)
41
-
42
- return result_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  demo = gr.Interface(
45
- fn=predict,
46
- inputs=gr.Image(type="pil"),
47
- outputs="text",
48
- title="🚲 Công cụ Định giá Xe đạp AI",
49
- description="Tải ảnh lên để AI định giá và đưa ra lý do. Cảnh báo: Hệ thống sẽ tự động từ chối nếu bạn tải ảnh chó, mèo, hay đồ vật khác không phải xe đạp!"
 
 
50
  )
51
 
52
  demo.launch()
 
6
 
7
  print("Đang tải các mô hình AI...")
8
 
9
+ checker = pipeline("image-classification", model="google/vit-base-patch16-224")
10
+ model_id = "hquan21/ai-bike-pricing-gold-v2"
 
 
 
11
  processor = BlipProcessor.from_pretrained(model_id)
12
+ model = BlipForConditionalGeneration.from_pretrained(model_id)
13
+ model.eval()
14
+
15
+ print("Tải mô hình hoàn tất!")
16
 
17
+ TU_KHOA_XE = ['bicycle', 'bike', 'velocipede', 'mountain bike', 'tricycle']
18
 
19
  def predict(img):
20
+ # ── Gác cổng ──
21
+ try:
22
+ ket_qua = checker(img)
23
+ top5 = [r['label'].lower() for r in ket_qua[:5]]
24
+ is_bike = any(kw in label for label in top5 for kw in TU_KHOA_XE)
25
+ except Exception as e:
26
+ return json.dumps({
27
+ "success": False,
28
+ "gia_de_xuat": None,
29
+ "ly_do": None,
30
+ "error": f"Lỗi kiểm tra ảnh: {str(e)}"
31
+ }, ensure_ascii=False)
32
+
33
+ if not is_bike:
34
+ return json.dumps({
35
+ "success": False,
36
+ "gia_de_xuat": None,
37
+ "ly_do": None,
38
+ "error": "Không phải xe đạp! Vui lòng tải ảnh toàn thân xe."
39
+ }, ensure_ascii=False)
40
+
41
+ # ── Định giá ──
42
+ try:
43
+ with torch.no_grad():
44
+ inputs = processor(img, return_tensors="pt")
45
+ out = model.generate(**inputs, max_new_tokens=150)
46
+ result = processor.decode(out[0], skip_special_tokens=True)
47
+ except Exception as e:
48
+ return json.dumps({
49
+ "success": False,
50
+ "gia_de_xuat": None,
51
+ "ly_do": None,
52
+ "error": f"Lỗi định giá: {str(e)}"
53
+ }, ensure_ascii=False)
54
+
55
+ # ── Parse output ──
56
+ try:
57
+ data = json.loads(result)
58
+ data["success"] = True
59
+ data["error"] = None
60
+ return json.dumps(data, ensure_ascii=False)
61
+ except json.JSONDecodeError:
62
+ return json.dumps({
63
+ "success": True,
64
+ "gia_de_xuat": None,
65
+ "ly_do": result,
66
+ "error": None
67
+ }, ensure_ascii=False)
68
 
69
  demo = gr.Interface(
70
+ fn = predict,
71
+ inputs = gr.Image(type="pil"),
72
+ outputs = "text",
73
+ title = "Định giá Xe đạp AI",
74
+ description = "Tải ảnh toàn thân xe đạp. Hệ thống tự động từ chối ảnh không phải xe đạp.",
75
+ api_name = "predict",
76
+ examples = [],
77
  )
78
 
79
  demo.launch()
requirements.txt CHANGED
@@ -2,3 +2,4 @@ gradio>=6.9.0
2
  transformers>=4.39.0
3
  torch>=2.2.0
4
  Pillow>=10.0.0
 
 
2
  transformers>=4.39.0
3
  torch>=2.2.0
4
  Pillow>=10.0.0
5
+ accelerate>=0.27.0