yqcyqc commited on
Commit
9dc8386
·
verified ·
1 Parent(s): ac91f22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -13,7 +13,7 @@ import threading
13
  import concurrent.futures
14
 
15
  # 加载类别名称
16
- with open('output/class_names.pkl', 'rb') as f:
17
  class_names = pickle.load(f)
18
 
19
  # 初始化模型
@@ -23,7 +23,7 @@ model.fc = nn.Sequential(
23
  nn.Dropout(0.2),
24
  nn.Linear(model.fc.in_features, len(class_names))
25
  )
26
- model.load_state_dict(torch.load('output/best_model.pth', map_location=device))
27
  model = model.to(device)
28
  model.eval()
29
 
@@ -137,12 +137,12 @@ def predict_realtime(video_frame, remove_bg):
137
 
138
  def create_interface():
139
  examples = [
140
- "data/r0_0_100.jpg",
141
- "data/r0_18_100.jpg",
142
- "data/9_100.jpg",
143
- "data/127_100.jpg",
144
- "data/5ecc819f1a579f513e0a1500fabb3f0.png",
145
- "data/1105.jpg"
146
  ]
147
 
148
  with gr.Blocks(title="Fruit Classification", theme=gr.themes.Soft()) as demo:
 
13
  import concurrent.futures
14
 
15
  # 加载类别名称
16
+ with open('class_names.pkl', 'rb') as f:
17
  class_names = pickle.load(f)
18
 
19
  # 初始化模型
 
23
  nn.Dropout(0.2),
24
  nn.Linear(model.fc.in_features, len(class_names))
25
  )
26
+ model.load_state_dict(torch.load('best_model.pth', map_location=device))
27
  model = model.to(device)
28
  model.eval()
29
 
 
137
 
138
  def create_interface():
139
  examples = [
140
+ "r0_0_100.jpg",
141
+ "r0_18_100.jpg",
142
+ "9_100.jpg",
143
+ "127_100.jpg",
144
+ "5ecc819f1a579f513e0a1500fabb3f0.png",
145
+ "1105.jpg"
146
  ]
147
 
148
  with gr.Blocks(title="Fruit Classification", theme=gr.themes.Soft()) as demo: