Spaces:
Sleeping
Sleeping
| import os | |
| from transformers import AutoFeatureExtractor, AutoModelForObjectDetection | |
| import torch | |
| from huggingface_hub import login | |
| import logging | |
| from transformers import AutoProcessor, AutoModelForVision2Seq | |
| from PIL import Image | |
| import numpy as np | |
| from config import MODEL_NAME | |
# Configure logging at import time: INFO threshold, with timestamp, logger
# name and level prepended to every record.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by everything defined in this file.
logger = logging.getLogger(__name__)
class RadarDetectionModel:
    """Load a Hugging Face vision-to-sequence model and use it to flag defects in radar images.

    The model generates a text description of the image; that text is then
    scanned for defect keywords by ``_parse_detection_results``. The boxes and
    scores returned are fixed placeholders per keyword, not model outputs.
    """

    # Keyword rules used by _parse_detection_results, checked in this order:
    # (keywords to look for, reported label, fixed score, normalized x1/y1/x2/y2 box).
    _DEFECT_RULES = (
        (("crack", "裂缝"), "裂缝", 0.92, (0.2, 0.3, 0.4, 0.5)),
        (("corrosion", "腐蚀"), "腐蚀", 0.85, (0.6, 0.2, 0.8, 0.4)),
        (("damage", "损坏"), "损坏", 0.78, (0.1, 0.7, 0.3, 0.9)),
        (("defect", "缺陷"), "缺陷", 0.88, (0.5, 0.5, 0.7, 0.7)),
    )
    # Generic entry reported when no keyword matches at all.
    _FALLBACK_RULE = ("异常", 0.75, (0.4, 0.4, 0.6, 0.6))

    def __init__(self, model_name=None, use_auth_token=None):
        """Pick the device, resolve the model name and token, then load the model.

        Args:
            model_name (str, optional): Model name or path to load. Falls back
                to ``MODEL_NAME`` from the config module when empty.
            use_auth_token (str, optional): Hugging Face token for gated
                models. Falls back to the ``HF_TOKEN`` environment variable.

        Raises:
            Exception: Whatever ``_load_model`` raises when loading fails.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"使用设备: {self.device}")

        self.model_name = model_name or MODEL_NAME
        logger.info(f"模型名称: {self.model_name}")

        self.use_auth_token = use_auth_token or os.environ.get("HF_TOKEN")
        if self.use_auth_token:
            logger.info("已提供Hugging Face令牌")
        else:
            logger.warning("未提供Hugging Face令牌,可能无法访问受限模型")

        self.processor = None
        self.model = None

        # Loading happens eagerly so a broken configuration fails at construction.
        logger.info("开始加载模型和处理器...")
        self._load_model()

    @staticmethod
    def _clock_ms():
        """Return a monotonic wall-clock timestamp in milliseconds.

        When CUDA is available, pending GPU work is synchronized first so the
        difference between two timestamps also covers asynchronous kernels.
        """
        import time  # function-scope import: the file's import block does not provide `time`

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return time.perf_counter() * 1000.0

    def _load_pretrained_model(self, auth_kwargs):
        """Load the model with 8-bit quantization, retrying unquantized on CPU-only hosts.

        Args:
            auth_kwargs (dict): Extra keyword arguments (the auth token, if
                any) forwarded to ``from_pretrained``.

        Returns:
            The loaded model instance.

        Raises:
            Exception: Any load error on CUDA hosts, or the error of the
                unquantized retry on CPU-only hosts.
        """
        try:
            return AutoModelForVision2Seq.from_pretrained(
                self.model_name,
                load_in_8bit=True,   # 8-bit quantization to reduce memory use
                device_map="auto",   # let the library place the weights
                **auth_kwargs,
            )
        except (ImportError, RuntimeError, ValueError) as exc:
            # With a GPU present the failure is not the "no GPU for 8-bit" case,
            # so surface it unchanged, exactly as before.
            if torch.cuda.is_available():
                raise
            logger.warning(f"8位量化加载失败({exc}),当前无可用GPU,改为不量化在CPU上加载")
            return AutoModelForVision2Seq.from_pretrained(self.model_name, **auth_kwargs)

    def _load_model(self):
        """Load the processor and the model, logging load times and memory use.

        Raises:
            Exception: Any error raised while logging in or loading; it is
                logged with its traceback and re-raised.
        """
        try:
            logger.info(f"正在从{self.model_name}加载处理器")
            started = self._clock_ms()

            auth_kwargs = {}
            if self.use_auth_token:
                # Log in to the Hub when a token is available.
                logger.info("使用令牌登录到Hugging Face Hub")
                login(token=self.use_auth_token)
                # NOTE(review): newer transformers releases prefer `token=`;
                # `use_auth_token` is kept to match the installed version — confirm.
                auth_kwargs["use_auth_token"] = self.use_auth_token

            self.processor = AutoProcessor.from_pretrained(self.model_name, **auth_kwargs)
            logger.info(f"处理器加载时间: {self._clock_ms() - started:.2f}毫秒")

            logger.info(f"正在从{self.model_name}加载模型,使用8位量化以减少内存使用")
            started = self._clock_ms()
            self.model = self._load_pretrained_model(auth_kwargs)
            logger.info(f"模型加载时间: {self._clock_ms() - started:.2f}毫秒")
            logger.info("模型加载成功")

            # No manual .to(device): placement is handled at load time.
            self.model.eval()

            # Record model size and, on GPU, the current memory footprint.
            param_count = sum(p.numel() for p in self.model.parameters())
            logger.info(f"模型参数数量: {param_count:,}")
            if torch.cuda.is_available():
                mib = 1024 * 1024
                memory_allocated = torch.cuda.memory_allocated() / mib
                memory_reserved = torch.cuda.memory_reserved() / mib
                logger.info(f"GPU内存分配: {memory_allocated:.2f}MB")
                logger.info(f"GPU内存保留: {memory_reserved:.2f}MB")
        except Exception as e:
            logger.error(f"加载模型时出错: {str(e)}", exc_info=True)
            raise

    def detect(self, image):
        """Run the model on a radar image and return the parsed detections.

        Args:
            image (PIL.Image): Radar image to analyze.

        Returns:
            dict: ``boxes`` (pixel ``[x1, y1, x2, y2]`` lists), ``scores``,
            ``labels`` and the input ``image``. This method never raises: on
            any failure it returns one placeholder box whose label carries the
            first 50 characters of the error message.
        """
        try:
            if self.model is None or self.processor is None:
                raise ValueError("模型或处理器未正确初始化")

            logger.info("预处理图像")
            # Inputs must live where the model was placed; fall back to
            # self.device if the model does not report a device.
            target_device = getattr(self.model, "device", self.device)
            inputs = self.processor(images=image, return_tensors="pt").to(target_device)

            logger.info("运行模型推理")
            started = self._clock_ms()
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=50,
                    num_beams=4,
                    early_stopping=True,
                )
            logger.info(f"推理时间: {self._clock_ms() - started:.2f}毫秒")

            # Decode the generated token ids back to text.
            generated_text = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
            logger.info(f"生成的文本: {generated_text}")

            # Turn the generated description into boxes/scores/labels.
            boxes, scores, labels = self._parse_detection_results(generated_text, image.size)
            logger.info(f"检测到{len(boxes)}个对象")

            return {
                'boxes': boxes,
                'scores': scores,
                'labels': labels,
                'image': image,
            }
        except Exception as e:
            logger.error(f"检测过程中出错: {str(e)}", exc_info=True)
            # Deliberate best-effort: hand back a placeholder result instead of raising.
            return {
                'boxes': [[100, 100, 200, 200]],
                'scores': [0.75],
                'labels': ['错误: ' + str(e)[:50]],
                'image': image,
            }

    def _parse_detection_results(self, text, image_size):
        """Derive detections from the generated text by keyword matching.

        This is a simplified placeholder: each matching keyword rule yields a
        fixed label, score and normalized box. When nothing matches, a single
        generic entry is reported.

        Args:
            text (str): Text generated by the model.
            image_size (tuple): Input image size as ``(width, height)``.

        Returns:
            tuple: ``(boxes, scores, labels)`` where each box is a list of
            integer pixel coordinates ``[x1, y1, x2, y2]``.
        """
        lowered = text.lower()  # lower-case once instead of once per keyword
        matches = [
            (label, score, box)
            for keywords, label, score, box in self._DEFECT_RULES
            if any(keyword in lowered for keyword in keywords)
        ]
        if not matches:
            matches.append(self._FALLBACK_RULE)

        # Scale the normalized boxes to pixel coordinates.
        width, height = image_size
        boxes = [
            [int(x1 * width), int(y1 * height), int(x2 * width), int(y2 * height)]
            for _, _, (x1, y1, x2, y2) in matches
        ]
        scores = [score for _, score, _ in matches]
        labels = [label for label, _, _ in matches]
        return boxes, scores, labels