yqcyqc commited on
Commit
1cfab24
·
verified ·
1 Parent(s): cea9b8b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -59
app.py CHANGED
@@ -7,10 +7,13 @@ from resnest.torch import resnest50
7
  from rembg import remove
8
  from PIL import Image
9
  import io
10
- import requests
 
 
 
11
 
12
  # 加载类别名称
13
- with open('class_names.pkl', 'rb') as f:
14
  class_names = pickle.load(f)
15
 
16
  # 初始化模型
@@ -20,7 +23,7 @@ model.fc = nn.Sequential(
20
  nn.Dropout(0.2),
21
  nn.Linear(model.fc.in_features, len(class_names))
22
  )
23
- model.load_state_dict(torch.load('best_model.pth', map_location=device))
24
  model = model.to(device)
25
  model.eval()
26
 
@@ -31,21 +34,30 @@ preprocess = transforms.Compose([
31
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
32
  ])
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  def remove_background(img):
36
  """使用rembg去除背景并添加白色背景"""
37
- # 转换图像为字节流
38
  img_byte_arr = io.BytesIO()
39
  img.save(img_byte_arr, format='PNG')
40
  img_bytes = img_byte_arr.getvalue()
41
 
42
- # 去除背景
43
  removed_bg_bytes = remove(img_bytes)
44
-
45
- # 转换为PIL图像并处理透明度
46
  removed_bg_img = Image.open(io.BytesIO(removed_bg_bytes)).convert('RGBA')
47
 
48
- # 创建白色背景
49
  white_bg = Image.new('RGBA', removed_bg_img.size, (255, 255, 255, 255))
50
  combined = Image.alpha_composite(white_bg, removed_bg_img)
51
  return combined.convert('RGB')
@@ -53,17 +65,14 @@ def remove_background(img):
53
 
54
  def predict_image(img, remove_bg=False):
55
  """分类预测主函数"""
56
- # 根据选择处理图像
57
  if remove_bg:
58
  processed_img = remove_background(img)
59
  else:
60
- processed_img = img.convert('RGB') # 确保为RGB格式
61
 
62
- # 预处理
63
  input_tensor = preprocess(processed_img)
64
  input_batch = input_tensor.unsqueeze(0).to(device)
65
 
66
- # 预测
67
  with torch.no_grad():
68
  output = model(input_batch)
69
 
@@ -71,75 +80,91 @@ def predict_image(img, remove_bg=False):
71
  top3_probs, top3_indices = torch.topk(probabilities, 3)
72
 
73
  results = {
74
- class_names[i]: p.item()
75
  for p, i in zip(top3_probs, top3_indices)
76
  }
77
 
78
- # 记录结果
79
  best_class = class_names[top3_indices[0]]
80
  best_conf = top3_probs[0].item() * 100
81
-
82
-
83
- # 新增:调用本地API保存结果
84
- api_url = "http://10.230.23.58:8806/save_result" # 替换为你的本地IP
85
- payload = {
86
- "filename": "uploaded_image.jpg", # 可改为实际文件名
87
- "class": best_class,
88
- "confidence": f"{best_conf:.2f}%"
89
- }
90
- try:
91
- requests.post(api_url, json=payload, timeout=3)
92
- except Exception as e:
93
- print(f"保存到数据库失败: {e}")
94
 
95
- return processed_img, best_class, f"{best_conf:.2f}%", results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
 
98
  def create_interface():
99
  examples = [
100
- "r0_0_100.jpg",
101
- "r0_18_100.jpg",
102
- "9_100.jpg",
103
- "100_100.jpg",
104
- "1105.jpg",
105
- "5ecc819f1a579f513e0a1500fabb3f0.png"
106
  ]
107
 
108
  with gr.Blocks(title="Fruit Classification", theme=gr.themes.Soft()) as demo:
109
- gr.Markdown("""
110
- # 🍎 智能水果识别系统
111
- """)
112
 
113
- # 新增:模式选择卡片(视觉强化)
114
  with gr.Row():
115
  with gr.Column(scale=3):
116
  with gr.Group():
117
- gr.Markdown("### ⚙️ 处理模式选择")
118
  with gr.Row():
119
- bg_removal = gr.Checkbox(
120
- label="背景去除",
121
- value=False,
122
- interactive=True
123
- )
124
 
125
- # 主操作区域
126
- with gr.Row():
127
- with gr.Column():
128
- original_image = gr.Image(label="📤 上传图片", type="pil")
129
- gr.Examples(examples=examples, inputs=original_image)
130
  submit_btn = gr.Button("🚀 开始识别", variant="primary")
131
 
132
- # 添加模式说明提示
133
- gr.Markdown("""
134
- <div style="background: #f3f4f6; padding: 15px; border-radius: 8px; margin-top: 10px">
135
- <b>💡 使用建议:</b><br>
136
- • 上传图片:选择一张图片,点击'开始识别'按钮<br>
137
- • 勾选背景去除:适合杂乱背景的图片(识别更准确)<br>
138
- • 不勾选:适合纯色背景的图片(速度更快)
139
- </div>
140
- """)
141
 
142
  with gr.Column():
 
143
  processed_image = gr.Image(label="🖼️ 处理后图片", interactive=False)
144
  best_pred = gr.Textbox(label="🔍 识别结果")
145
  confidence = gr.Textbox(label="📊 置信度")
@@ -148,7 +173,13 @@ def create_interface():
148
  submit_btn.click(
149
  fn=predict_image,
150
  inputs=[original_image, bg_removal],
151
- outputs=[processed_image, best_pred, confidence, full_results]
 
 
 
 
 
 
152
  )
153
 
154
  return demo
 
7
  from rembg import remove
8
  from PIL import Image
9
  import io
10
+ import json
11
+ import time
12
+ import threading
13
+ import concurrent.futures
14
 
15
  # 加载类别名称
16
+ with open('output/class_names.pkl', 'rb') as f:
17
  class_names = pickle.load(f)
18
 
19
  # 初始化模型
 
23
  nn.Dropout(0.2),
24
  nn.Linear(model.fc.in_features, len(class_names))
25
  )
26
+ model.load_state_dict(torch.load('output/best_model.pth', map_location=device))
27
  model = model.to(device)
28
  model.eval()
29
 
 
34
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
35
  ])
36
 
37
+ # 创建线程池
38
+ executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
39
+
40
+
41
+ class RealtimeState:
42
+ def __init__(self):
43
+ self.last_result = None
44
+ self.last_update_time = 0
45
+ self.is_processing = False
46
+ self.lock = threading.Lock()
47
+
48
+
49
+ realtime_state = RealtimeState()
50
+
51
 
52
  def remove_background(img):
53
  """使用rembg去除背景并添加白色背景"""
 
54
  img_byte_arr = io.BytesIO()
55
  img.save(img_byte_arr, format='PNG')
56
  img_bytes = img_byte_arr.getvalue()
57
 
 
58
  removed_bg_bytes = remove(img_bytes)
 
 
59
  removed_bg_img = Image.open(io.BytesIO(removed_bg_bytes)).convert('RGBA')
60
 
 
61
  white_bg = Image.new('RGBA', removed_bg_img.size, (255, 255, 255, 255))
62
  combined = Image.alpha_composite(white_bg, removed_bg_img)
63
  return combined.convert('RGB')
 
65
 
66
  def predict_image(img, remove_bg=False):
67
  """分类预测主函数"""
 
68
  if remove_bg:
69
  processed_img = remove_background(img)
70
  else:
71
+ processed_img = img.convert('RGB')
72
 
 
73
  input_tensor = preprocess(processed_img)
74
  input_batch = input_tensor.unsqueeze(0).to(device)
75
 
 
76
  with torch.no_grad():
77
  output = model(input_batch)
78
 
 
80
  top3_probs, top3_indices = torch.topk(probabilities, 3)
81
 
82
  results = {
83
+ class_names[i]: round(p.item(), 4)
84
  for p, i in zip(top3_probs, top3_indices)
85
  }
86
 
 
87
  best_class = class_names[top3_indices[0]]
88
  best_conf = top3_probs[0].item() * 100
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
+ with open('output/prediction_results.txt', 'a') as f:
91
+ f.write(f"Remove BG: {remove_bg}\n")
92
+ f.write(f"Predicted: {best_class} ({best_conf:.2f}%)\n")
93
+ f.write(f"Top 3: {results}\n\n")
94
+
95
+ return None, processed_img, best_class, f"{best_conf:.2f}%", results
96
+
97
+
98
+ def predict_realtime(video_frame, remove_bg):
99
+ """实时预测主函数,结果保留2秒"""
100
+ global realtime_state
101
+
102
+ if video_frame is None:
103
+ return None, None, None, None, None
104
+
105
+ current_time = time.time()
106
+
107
+ # 检查是否有未过期的结果
108
+ with realtime_state.lock:
109
+ if realtime_state.last_result and current_time - realtime_state.last_update_time < 2:
110
+ return realtime_state.last_result
111
+
112
+ # 如果正在处理中,返回None
113
+ if realtime_state.is_processing:
114
+ return None, None, None, None, None
115
+
116
+ # 标记为正在处理
117
+ realtime_state.is_processing = True
118
+
119
+ # 异步处理帧
120
+ def process_frame():
121
+ try:
122
+ result = predict_image(video_frame, remove_bg)
123
+ with realtime_state.lock:
124
+ realtime_state.last_result = result
125
+ realtime_state.last_update_time = time.time()
126
+ realtime_state.is_processing = False
127
+ except Exception as e:
128
+ print(f"处理帧时出错: {e}")
129
+ with realtime_state.lock:
130
+ realtime_state.is_processing = False
131
+
132
+ # 提交到线程池处理
133
+ executor.submit(process_frame)
134
+
135
+ return None, None, None, None, None
136
 
137
 
138
  def create_interface():
139
  examples = [
140
+ "data/r0_0_100.jpg",
141
+ "data/r0_18_100.jpg",
142
+ "data/9_100.jpg",
143
+ "data/127_100.jpg",
144
+ "data/5ecc819f1a579f513e0a1500fabb3f0.png",
145
+ "data/1105.jpg"
146
  ]
147
 
148
  with gr.Blocks(title="Fruit Classification", theme=gr.themes.Soft()) as demo:
149
+ gr.Markdown("""# 🍎 智能水果识别系统""")
 
 
150
 
 
151
  with gr.Row():
152
  with gr.Column(scale=3):
153
  with gr.Group():
154
+ gr.Markdown("## ⚙️ 处理模式选择")
155
  with gr.Row():
156
+ bg_removal = gr.Checkbox(label="背景去除", value=False, interactive=True)
157
+ with gr.Column():
158
+ original_image = gr.Image(label="📤 上传图片", type="pil")
159
+ gr.Examples(examples=examples, inputs=original_image)
 
160
 
 
 
 
 
 
161
  submit_btn = gr.Button("🚀 开始识别", variant="primary")
162
 
163
+ gr.Markdown("""## ⚡ 实时识别""")
164
+ camera = gr.Image(label="📷 摄像头捕获", type="pil", streaming=True)
 
 
 
 
 
 
 
165
 
166
  with gr.Column():
167
+ prediction_id_output = gr.Textbox(label="🔍 预测ID", interactive=False, visible=False)
168
  processed_image = gr.Image(label="🖼️ 处理后图片", interactive=False)
169
  best_pred = gr.Textbox(label="🔍 识别结果")
170
  confidence = gr.Textbox(label="📊 置信度")
 
173
  submit_btn.click(
174
  fn=predict_image,
175
  inputs=[original_image, bg_removal],
176
+ outputs=[prediction_id_output, processed_image, best_pred, confidence, full_results]
177
+ )
178
+
179
+ camera.stream(
180
+ fn=predict_realtime,
181
+ inputs=[camera, bg_removal],
182
+ outputs=[prediction_id_output, processed_image, best_pred, confidence, full_results]
183
  )
184
 
185
  return demo