Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| import pickle | |
| from resnest.torch import resnest50 | |
| from rembg import remove | |
| from PIL import Image | |
| import io | |
| import json | |
| import time | |
| import threading | |
| import concurrent.futures | |
# Load the class labels from the pickle file. The result is indexed by the
# model's output positions and its length sizes the classifier head.
with open('class_names.pkl', mode='rb') as names_file:
    class_names = pickle.load(names_file)
# Build the network: a ResNeSt-50 created without pretrained weights whose
# final layer is replaced by dropout + a linear layer with one output per
# class name, then populated from the 'best_model.pth' state dict.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = resnest50(pretrained=False)
head_in_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(head_in_features, len(class_names)),
)
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model = model.to(device)
# Evaluation mode (e.g. dropout is disabled) since the app only runs inference.
model.eval()
# Input pipeline applied to every image before it reaches the model:
# resize to 100x100, convert to a tensor, then normalise each channel with
# the mean / std values given here.
preprocess = transforms.Compose([
    transforms.Resize((100, 100)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# Thread pool for background frame predictions. It has a single worker, so
# submitted jobs run one at a time.
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
class RealtimeState:
    """Mutable bookkeeping shared by the streaming callback and its worker.

    All attributes other than ``lock`` are meant to be read and written while
    holding ``lock``.
    """

    def __init__(self):
        # Guards the three fields below.
        self.lock = threading.Lock()
        # True while a frame has been handed to the worker and is unfinished.
        self.is_processing = False
        # time.time() value recorded when last_result was stored (0 = never).
        self.last_update_time = 0
        # Most recent prediction tuple, or None before the first one.
        self.last_result = None


# Single module-level instance used by predict_realtime.
realtime_state = RealtimeState()
def remove_background(img):
    """Strip the background with rembg and flatten the result onto white.

    The image is serialised to PNG bytes, passed through ``remove``, and the
    RGBA output is alpha-composited over an opaque white canvas of the same
    size.

    Args:
        img: PIL image to process.

    Returns:
        A new RGB PIL image whose removed background is white.
    """
    png_buffer = io.BytesIO()
    img.save(png_buffer, format='PNG')
    cutout_bytes = remove(png_buffer.getvalue())

    cutout = Image.open(io.BytesIO(cutout_bytes)).convert('RGBA')
    canvas = Image.new('RGBA', cutout.size, (255, 255, 255, 255))
    return Image.alpha_composite(canvas, cutout).convert('RGB')
def predict_image(img, remove_bg=False):
    """Classify an image and return the values for the five output widgets.

    Each successful prediction is also appended to
    ``output/prediction_results.txt``.

    Args:
        img: PIL image to classify, or ``None`` when no image was supplied.
        remove_bg: When true, the image is first passed through
            ``remove_background``; otherwise it is only converted to RGB.

    Returns:
        A 5-tuple ``(prediction_id, processed_image, best_class,
        confidence_text, top_results)``. ``prediction_id`` is always ``None``.
        ``top_results`` maps up to three class names to probabilities rounded
        to four decimals. All five entries are ``None`` when ``img`` is
        ``None``.
    """
    import os  # local import: keeps this function self-contained

    # Nothing to classify (e.g. the button was pressed with no image loaded).
    if img is None:
        return None, None, None, None, None

    processed_img = remove_background(img) if remove_bg else img.convert('RGB')

    # Add a batch dimension and move the tensor to the model's device.
    input_batch = preprocess(processed_img).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_batch)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

    # topk raises if k exceeds the number of classes, so clamp it.
    top_k = min(3, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, top_k)
    # Plain Python floats / ints so class_names is indexed with real ints.
    top_probs = top_probs.tolist()
    top_indices = top_indices.tolist()

    results = {
        class_names[idx]: round(prob, 4)
        for prob, idx in zip(top_probs, top_indices)
    }
    best_class = class_names[top_indices[0]]
    best_conf = top_probs[0] * 100

    # Create the log directory on first use; open() does not create it.
    log_path = 'output/prediction_results.txt'
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    # Explicit UTF-8 so non-ASCII class names never fail under other locales.
    with open(log_path, 'a', encoding='utf-8') as f:
        f.write(
            f"Remove BG: {remove_bg}\n"
            f"Predicted: {best_class} ({best_conf:.2f}%)\n"
            f"Top 3: {results}\n\n"
        )
    return None, processed_img, best_class, f"{best_conf:.2f}%", results
def predict_realtime(video_frame, remove_bg):
    """Streaming callback: classify frames in the background, hold results 2 s.

    A stored result younger than two seconds is returned as-is. Otherwise, if
    no prediction is already running, the frame is submitted to ``executor``
    and the call returns immediately with empty outputs; the result becomes
    visible on a later call once the worker has stored it.

    Args:
        video_frame: Current frame as a PIL image, or ``None``.
        remove_bg: Forwarded to ``predict_image``.

    Returns:
        The cached 5-tuple from ``predict_image`` while it is fresh, otherwise
        a 5-tuple of ``None``.

    Raises:
        Whatever ``executor.submit`` raises (for example after the pool has
        been shut down); the in-progress flag is cleared first.
    """
    empty = (None, None, None, None, None)
    if video_frame is None:
        return empty

    now = time.time()
    # Check-and-set under one lock so two callbacks cannot both claim the slot.
    with realtime_state.lock:
        if realtime_state.last_result and now - realtime_state.last_update_time < 2:
            return realtime_state.last_result
        if realtime_state.is_processing:
            return empty
        realtime_state.is_processing = True

    def process_frame():
        """Run the prediction on the worker thread and publish the result."""
        try:
            result = predict_image(video_frame, remove_bg)
            with realtime_state.lock:
                realtime_state.last_result = result
                realtime_state.last_update_time = time.time()
        except Exception as e:
            # Worker boundary: report and carry on so the stream keeps going.
            print(f"处理帧时出错: {e}")
        finally:
            # Always release the slot, whether the prediction succeeded or not.
            with realtime_state.lock:
                realtime_state.is_processing = False

    try:
        executor.submit(process_frame)
    except Exception:
        # The job never started, so nothing else will clear the flag.
        with realtime_state.lock:
            realtime_state.is_processing = False
        raise
    return empty
def create_interface():
    """Build the Gradio Blocks UI and wire both prediction handlers to it.

    Returns:
        The ``gr.Blocks`` app. It is not launched here.
    """
    # NOTE(review): the nesting of the Row / Column / Group containers below
    # was reconstructed from a copy of this file that had lost its
    # indentation -- confirm the hierarchy against the original layout.

    # Image paths offered as clickable examples for the upload widget.
    examples = [
        "r0_0_100.jpg",
        "r0_18_100.jpg",
        "9_100.jpg",
        "127_100.jpg",
        "5ecc819f1a579f513e0a1500fabb3f0.png",
        "1105.jpg"
    ]
    with gr.Blocks(title="Fruit Classification", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""# 🍎 智能水果识别系统""")
        with gr.Row():
            with gr.Column(scale=3):
                # Processing options; the checkbox feeds both handlers.
                with gr.Group():
                    gr.Markdown("## ⚙️ 处理模式选择")
                    with gr.Row():
                        bg_removal = gr.Checkbox(label="背景去除", value=False, interactive=True)
                # Inputs: a still-image upload with its button, and a
                # streaming image component for live frames.
                with gr.Column():
                    original_image = gr.Image(label="📤 上传图片", type="pil")
                    gr.Examples(examples=examples, inputs=original_image)
                    submit_btn = gr.Button("🚀 开始识别", variant="primary")
                    gr.Markdown("""## ⚡ 实时识别""")
                    camera = gr.Image(label="📷 摄像头捕获", type="pil", streaming=True)
            # Outputs shared by both handlers, in the order of the 5-tuple
            # they return.
            with gr.Column():
                # Hidden field; predict_image returns None for this slot.
                prediction_id_output = gr.Textbox(label="🔍 预测ID", interactive=False, visible=False)
                processed_image = gr.Image(label="🖼️ 处理后图片", interactive=False)
                best_pred = gr.Textbox(label="🔍 识别结果")
                confidence = gr.Textbox(label="📊 置信度")
                full_results = gr.Label(label="🏆 Top 3 可能结果", num_top_classes=3)
        # One-shot prediction when the button is clicked.
        submit_btn.click(
            fn=predict_image,
            inputs=[original_image, bg_removal],
            outputs=[prediction_id_output, processed_image, best_pred, confidence, full_results]
        )
        # Prediction on streamed frames from the camera component.
        camera.stream(
            fn=predict_realtime,
            inputs=[camera, bg_removal],
            outputs=[prediction_id_output, processed_image, best_pred, confidence, full_results]
        )
    return demo
if __name__ == "__main__":
    # Build the UI and serve it, passing share=True to launch().
    create_interface().launch(share=True)