Spaces:
Sleeping
Sleeping
chenxingqiang
commited on
Commit
·
3228ab0
1
Parent(s):
d6ecb31
Optimize model loading and improve user experience
Browse files- README.md +36 -0
- __pycache__/app.cpython-311.pyc +0 -0
- __pycache__/config.cpython-311.pyc +0 -0
- __pycache__/feature_extraction.cpython-311.pyc +0 -0
- __pycache__/model.cpython-311.pyc +0 -0
- app.py +648 -205
- config.py +1 -1
- create_space.py +11 -1
- feature_extraction.py +24 -7
- model.py +161 -57
- requirements.txt +13 -11
- run.py +38 -0
- test_app.py +129 -0
- utils.py +5 -3
README.md
CHANGED
|
@@ -38,6 +38,36 @@ This Hugging Face Space provides an interactive interface for analyzing radar im
|
|
| 38 |
3. View the detection results and analysis report
|
| 39 |
4. Access previous analyses through the history feature
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
## Technical Details
|
| 42 |
|
| 43 |
- **Model**: PaliGemma-3b fine-tuned for radar defect detection
|
|
@@ -52,6 +82,12 @@ The following environment variables need to be set in your Space:
|
|
| 52 |
- `HF_TOKEN`: Your Hugging Face token for accessing the model
|
| 53 |
- `DATABASE_URL` (optional): URL for the database connection
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
## Requirements
|
| 56 |
|
| 57 |
See `requirements.txt` for the complete list of dependencies.
|
|
|
|
| 38 |
3. View the detection results and analysis report
|
| 39 |
4. Access previous analyses through the history feature
|
| 40 |
|
| 41 |
+
## Setup Instructions
|
| 42 |
+
|
| 43 |
+
### Local Development
|
| 44 |
+
|
| 45 |
+
1. Clone this repository:
|
| 46 |
+
```bash
|
| 47 |
+
git clone https://huggingface.co/spaces/xingqiang/radar-analysis
|
| 48 |
+
cd radar-analysis
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
2. Install dependencies:
|
| 52 |
+
```bash
|
| 53 |
+
pip install -r requirements.txt
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
3. Set up environment variables:
|
| 57 |
+
- Create a `.env` file in the root directory
|
| 58 |
+
- Add your Hugging Face token: `HF_TOKEN=your_token_here`
|
| 59 |
+
|
| 60 |
+
4. Run the application:
|
| 61 |
+
```bash
|
| 62 |
+
python app.py
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
### Hugging Face Space Deployment
|
| 66 |
+
|
| 67 |
+
1. Fork this repository to your Hugging Face account
|
| 68 |
+
2. Set the `HF_TOKEN` secret in your Space settings
|
| 69 |
+
3. Deploy the Space
|
| 70 |
+
|
| 71 |
## Technical Details
|
| 72 |
|
| 73 |
- **Model**: PaliGemma-3b fine-tuned for radar defect detection
|
|
|
|
| 82 |
- `HF_TOKEN`: Your Hugging Face token for accessing the model
|
| 83 |
- `DATABASE_URL` (optional): URL for the database connection
|
| 84 |
|
| 85 |
+
## Troubleshooting
|
| 86 |
+
|
| 87 |
+
- **Memory Issues**: The application will automatically switch to demo mode if there's not enough memory
|
| 88 |
+
- **Model Loading Errors**: Check that your Hugging Face token has access to the required model
|
| 89 |
+
- **Image Processing Errors**: Ensure uploaded images are in a supported format (PNG, JPG)
|
| 90 |
+
|
| 91 |
## Requirements
|
| 92 |
|
| 93 |
See `requirements.txt` for the complete list of dependencies.
|
__pycache__/app.cpython-311.pyc
ADDED
|
Binary file (53 kB). View file
|
|
|
__pycache__/config.cpython-311.pyc
ADDED
|
Binary file (762 Bytes). View file
|
|
|
__pycache__/feature_extraction.cpython-311.pyc
ADDED
|
Binary file (4.36 kB). View file
|
|
|
__pycache__/model.cpython-311.pyc
ADDED
|
Binary file (12.4 kB). View file
|
|
|
app.py
CHANGED
|
@@ -12,6 +12,7 @@ import plotly.express as px
|
|
| 12 |
import plotly.graph_objects as go
|
| 13 |
import pandas as pd
|
| 14 |
from functools import partial
|
|
|
|
| 15 |
|
| 16 |
from model import RadarDetectionModel
|
| 17 |
from feature_extraction import (calculate_amplitude, classify_amplitude,
|
|
@@ -22,6 +23,12 @@ from feature_extraction import (calculate_amplitude, classify_amplitude,
|
|
| 22 |
from report_generation import generate_report, render_report
|
| 23 |
from utils import plot_detection
|
| 24 |
from database import save_report, get_report_history
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
# Set theme and styling
|
| 27 |
THEME = gr.themes.Soft(
|
|
@@ -35,32 +42,40 @@ THEME = gr.themes.Soft(
|
|
| 35 |
# Create a simple dark mode flag instead of custom theme
|
| 36 |
DARK_MODE = False
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
class TechnicalReportGenerator:
|
| 39 |
def __init__(self):
|
| 40 |
self.timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 41 |
-
|
| 42 |
def generate_model_analysis(self, model_outputs):
|
| 43 |
"""Generate model-specific analysis section"""
|
| 44 |
model_section = "## Model Analysis\n\n"
|
| 45 |
-
|
| 46 |
# Image encoder analysis
|
| 47 |
model_section += "### Image Encoder (SigLIP-So400m) Analysis\n"
|
| 48 |
model_section += "- Feature extraction quality: {:.2f}%\n".format(model_outputs.get('feature_quality', 0) * 100)
|
| 49 |
model_section += "- Image encoding latency: {:.2f}ms\n".format(model_outputs.get('encoding_latency', 0))
|
| 50 |
model_section += "- Feature map dimensions: {}\n\n".format(model_outputs.get('feature_dimensions', 'N/A'))
|
| 51 |
-
|
| 52 |
# Text decoder analysis
|
| 53 |
model_section += "### Text Decoder (Gemma-2B) Analysis\n"
|
| 54 |
model_section += "- Text generation confidence: {:.2f}%\n".format(model_outputs.get('text_confidence', 0) * 100)
|
| 55 |
model_section += "- Decoding latency: {:.2f}ms\n".format(model_outputs.get('decoding_latency', 0))
|
| 56 |
model_section += "- Token processing rate: {:.2f} tokens/sec\n\n".format(model_outputs.get('token_rate', 0))
|
| 57 |
-
|
| 58 |
return model_section
|
| 59 |
|
| 60 |
def generate_detection_analysis(self, detection_results):
|
| 61 |
"""Generate detailed detection analysis section"""
|
| 62 |
detection_section = "## Detection Analysis\n\n"
|
| 63 |
-
|
| 64 |
# Detection metrics
|
| 65 |
detection_section += "### Object Detection Metrics\n"
|
| 66 |
detection_section += "| Metric | Value |\n"
|
|
@@ -72,29 +87,29 @@ class TechnicalReportGenerator:
|
|
| 72 |
detection_section += "| Processing Time | {:.2f}ms |\n\n".format(
|
| 73 |
detection_results.get('processing_time', 0)
|
| 74 |
)
|
| 75 |
-
|
| 76 |
# Detailed detection results
|
| 77 |
detection_section += "### Detection Details\n"
|
| 78 |
detection_section += "| Object | Confidence | Bounding Box |\n"
|
| 79 |
detection_section += "|--------|------------|---------------|\n"
|
| 80 |
-
|
| 81 |
boxes = detection_results.get('boxes', [])
|
| 82 |
scores = detection_results.get('scores', [])
|
| 83 |
labels = detection_results.get('labels', [])
|
| 84 |
-
|
| 85 |
for box, score, label in zip(boxes, scores, labels):
|
| 86 |
detection_section += "| {} | {:.2f}% | {} |\n".format(
|
| 87 |
label,
|
| 88 |
score * 100,
|
| 89 |
[round(coord, 2) for coord in box]
|
| 90 |
)
|
| 91 |
-
|
| 92 |
return detection_section
|
| 93 |
|
| 94 |
def generate_multimodal_analysis(self, mm_results):
|
| 95 |
"""Generate multimodal analysis section"""
|
| 96 |
mm_section = "## Multimodal Analysis\n\n"
|
| 97 |
-
|
| 98 |
# Feature correlation analysis
|
| 99 |
mm_section += "### Feature Correlation Analysis\n"
|
| 100 |
mm_section += "- Text-Image Alignment Score: {:.2f}%\n".format(
|
|
@@ -106,19 +121,19 @@ class TechnicalReportGenerator:
|
|
| 106 |
mm_section += "- Feature Space Correlation: {:.2f}\n\n".format(
|
| 107 |
mm_results.get('feature_correlation', 0)
|
| 108 |
)
|
| 109 |
-
|
| 110 |
return mm_section
|
| 111 |
|
| 112 |
def generate_performance_metrics(self, perf_data):
|
| 113 |
"""Generate performance metrics section"""
|
| 114 |
perf_section = "## Performance Metrics\n\n"
|
| 115 |
-
|
| 116 |
# System metrics
|
| 117 |
perf_section += "### System Performance\n"
|
| 118 |
perf_section += "- Total Processing Time: {:.2f}ms\n".format(perf_data.get('total_time', 0))
|
| 119 |
perf_section += "- Peak Memory Usage: {:.2f}MB\n".format(perf_data.get('peak_memory', 0))
|
| 120 |
perf_section += "- GPU Utilization: {:.2f}%\n\n".format(perf_data.get('gpu_util', 0))
|
| 121 |
-
|
| 122 |
# Pipeline metrics
|
| 123 |
perf_section += "### Pipeline Statistics\n"
|
| 124 |
perf_section += "| Stage | Time (ms) | Memory (MB) |\n"
|
|
@@ -130,108 +145,370 @@ class TechnicalReportGenerator:
|
|
| 130 |
stats.get('time', 0),
|
| 131 |
stats.get('memory', 0)
|
| 132 |
)
|
| 133 |
-
|
| 134 |
return perf_section
|
| 135 |
|
| 136 |
def generate_report(self, results):
|
| 137 |
"""Generate comprehensive technical report"""
|
| 138 |
report = f"# Technical Analysis Report\nGenerated at: {self.timestamp}\n\n"
|
| 139 |
-
|
| 140 |
# Add model analysis
|
| 141 |
report += self.generate_model_analysis(results.get('model_outputs', {}))
|
| 142 |
-
|
| 143 |
# Add detection analysis
|
| 144 |
report += self.generate_detection_analysis(results.get('detection_results', {}))
|
| 145 |
-
|
| 146 |
# Add multimodal analysis
|
| 147 |
report += self.generate_multimodal_analysis(results.get('multimodal_results', {}))
|
| 148 |
-
|
| 149 |
# Add performance metrics
|
| 150 |
report += self.generate_performance_metrics(results.get('performance_data', {}))
|
| 151 |
-
|
| 152 |
return report
|
| 153 |
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
try:
|
| 164 |
-
|
| 165 |
-
|
|
|
|
|
|
|
| 166 |
except Exception as e:
|
| 167 |
-
|
| 168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
USE_DEMO_MODE = True
|
| 170 |
else:
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
def initialize_model():
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
|
|
|
| 182 |
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
USE_DEMO_MODE = True
|
| 197 |
-
return None
|
| 198 |
|
| 199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
|
| 201 |
def create_confidence_chart(scores, labels):
|
| 202 |
"""Create a bar chart for confidence scores"""
|
| 203 |
if not scores or not labels:
|
| 204 |
return None
|
| 205 |
-
|
| 206 |
df = pd.DataFrame({
|
| 207 |
'Label': labels,
|
| 208 |
'Confidence': [score * 100 for score in scores]
|
| 209 |
})
|
| 210 |
-
|
| 211 |
fig = px.bar(
|
| 212 |
-
df,
|
| 213 |
-
x='Label',
|
| 214 |
y='Confidence',
|
| 215 |
title='Detection Confidence Scores',
|
| 216 |
labels={'Confidence': 'Confidence (%)'},
|
| 217 |
color='Confidence',
|
| 218 |
color_continuous_scale='viridis'
|
| 219 |
)
|
| 220 |
-
|
| 221 |
fig.update_layout(
|
| 222 |
xaxis_title='Detected Object',
|
| 223 |
yaxis_title='Confidence (%)',
|
| 224 |
yaxis_range=[0, 100],
|
| 225 |
template='plotly_white'
|
| 226 |
)
|
| 227 |
-
|
| 228 |
return fig
|
| 229 |
|
| 230 |
def create_feature_radar_chart(features):
|
| 231 |
"""Create a radar chart for feature analysis"""
|
| 232 |
categories = list(features.keys())
|
| 233 |
values = []
|
| 234 |
-
|
| 235 |
# Convert text classifications to numeric values (1-5 scale)
|
| 236 |
for feature in features.values():
|
| 237 |
if "High" in feature:
|
|
@@ -246,16 +523,16 @@ def create_feature_radar_chart(features):
|
|
| 246 |
values.append(1)
|
| 247 |
else:
|
| 248 |
values.append(0)
|
| 249 |
-
|
| 250 |
fig = go.Figure()
|
| 251 |
-
|
| 252 |
fig.add_trace(go.Scatterpolar(
|
| 253 |
r=values,
|
| 254 |
theta=categories,
|
| 255 |
fill='toself',
|
| 256 |
name='Feature Analysis'
|
| 257 |
))
|
| 258 |
-
|
| 259 |
fig.update_layout(
|
| 260 |
polar=dict(
|
| 261 |
radialaxis=dict(
|
|
@@ -266,108 +543,109 @@ def create_feature_radar_chart(features):
|
|
| 266 |
title='Feature Analysis Radar Chart',
|
| 267 |
template='plotly_white'
|
| 268 |
)
|
| 269 |
-
|
| 270 |
return fig
|
| 271 |
|
| 272 |
def create_heatmap(image_array):
|
| 273 |
"""Create a heatmap visualization of the image intensity"""
|
| 274 |
if image_array is None:
|
| 275 |
return None
|
| 276 |
-
|
| 277 |
# Convert to grayscale if needed
|
| 278 |
if len(image_array.shape) == 3 and image_array.shape[2] == 3:
|
| 279 |
gray_img = np.mean(image_array, axis=2)
|
| 280 |
else:
|
| 281 |
gray_img = image_array
|
| 282 |
-
|
| 283 |
fig = px.imshow(
|
| 284 |
gray_img,
|
| 285 |
color_continuous_scale='inferno',
|
| 286 |
title='Signal Intensity Heatmap'
|
| 287 |
)
|
| 288 |
-
|
| 289 |
fig.update_layout(
|
| 290 |
xaxis_title='X Position',
|
| 291 |
yaxis_title='Y Position',
|
| 292 |
template='plotly_white'
|
| 293 |
)
|
| 294 |
-
|
| 295 |
return fig
|
| 296 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
def process_image_streaming(image, generate_tech_report=False, progress=gr.Progress()):
|
| 298 |
-
"""
|
| 299 |
if image is None:
|
| 300 |
-
raise gr.Error("
|
| 301 |
|
| 302 |
-
#
|
| 303 |
-
progress(0.1, desc="
|
|
|
|
| 304 |
global model, USE_DEMO_MODE
|
| 305 |
|
| 306 |
if not USE_DEMO_MODE:
|
| 307 |
-
model
|
| 308 |
-
if
|
| 309 |
-
progress(0.15, desc="
|
| 310 |
USE_DEMO_MODE = True
|
| 311 |
|
| 312 |
try:
|
| 313 |
-
#
|
| 314 |
if isinstance(image, np.ndarray):
|
| 315 |
image = Image.fromarray(image)
|
| 316 |
|
| 317 |
-
#
|
| 318 |
-
progress(0.2, desc="
|
|
|
|
| 319 |
|
| 320 |
if USE_DEMO_MODE:
|
| 321 |
-
#
|
| 322 |
detection_result = {
|
| 323 |
'boxes': [[100, 100, 200, 200], [300, 300, 400, 400]],
|
| 324 |
'scores': [0.92, 0.85],
|
| 325 |
-
'labels': ['
|
| 326 |
'image': image
|
| 327 |
}
|
| 328 |
else:
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
distribution_class = classify_distribution_range(distribution_range)
|
| 343 |
-
else:
|
| 344 |
-
distribution_class = "No defects detected"
|
| 345 |
-
|
| 346 |
-
attenuation_rate = calculate_attenuation_rate(np_image)
|
| 347 |
-
attenuation_class = classify_attenuation_rate(attenuation_rate)
|
| 348 |
-
|
| 349 |
-
reflection_count = count_reflections(np_image)
|
| 350 |
-
reflection_class = classify_reflections(reflection_count)
|
| 351 |
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
"Attenuation Rate": attenuation_class,
|
| 356 |
-
"Reflection Count": reflection_class
|
| 357 |
-
}
|
| 358 |
|
| 359 |
-
#
|
| 360 |
-
progress(0.5, desc="
|
| 361 |
confidence_chart = create_confidence_chart(
|
| 362 |
-
detection_result.get('scores', []),
|
| 363 |
detection_result.get('labels', [])
|
| 364 |
)
|
| 365 |
|
| 366 |
feature_chart = create_feature_radar_chart(features)
|
| 367 |
-
heatmap = create_heatmap(
|
| 368 |
|
| 369 |
-
#
|
| 370 |
-
progress(0.6, desc="
|
| 371 |
start_time = time.time()
|
| 372 |
performance_data = {
|
| 373 |
'pipeline_stats': {},
|
|
@@ -375,7 +653,7 @@ def process_image_streaming(image, generate_tech_report=False, progress=gr.Progr
|
|
| 375 |
'gpu_util': 0
|
| 376 |
}
|
| 377 |
|
| 378 |
-
#
|
| 379 |
stage_start = time.time()
|
| 380 |
detection_results = detection_result
|
| 381 |
detection_results['processing_time'] = (time.time() - stage_start) * 1000
|
|
@@ -384,7 +662,7 @@ def process_image_streaming(image, generate_tech_report=False, progress=gr.Progr
|
|
| 384 |
'memory': get_memory_usage()
|
| 385 |
}
|
| 386 |
|
| 387 |
-
#
|
| 388 |
stage_start = time.time()
|
| 389 |
model_outputs = {
|
| 390 |
'feature_quality': 0.85,
|
|
@@ -399,7 +677,7 @@ def process_image_streaming(image, generate_tech_report=False, progress=gr.Progr
|
|
| 399 |
'memory': get_memory_usage()
|
| 400 |
}
|
| 401 |
|
| 402 |
-
#
|
| 403 |
stage_start = time.time()
|
| 404 |
multimodal_results = {
|
| 405 |
'alignment_score': 0.78,
|
|
@@ -411,20 +689,20 @@ def process_image_streaming(image, generate_tech_report=False, progress=gr.Progr
|
|
| 411 |
'memory': get_memory_usage()
|
| 412 |
}
|
| 413 |
|
| 414 |
-
#
|
| 415 |
performance_data['total_time'] = (time.time() - start_time) * 1000
|
| 416 |
performance_data['peak_memory'] = get_peak_memory_usage()
|
| 417 |
performance_data['gpu_util'] = get_gpu_utilization()
|
| 418 |
|
| 419 |
-
#
|
| 420 |
-
progress(0.8, desc="
|
| 421 |
analysis_report = generate_report(detection_result, features)
|
| 422 |
|
| 423 |
-
#
|
| 424 |
output_image = plot_detection(image, detection_result)
|
| 425 |
|
| 426 |
if generate_tech_report:
|
| 427 |
-
#
|
| 428 |
tech_report_data = {
|
| 429 |
'model_outputs': model_outputs,
|
| 430 |
'detection_results': detection_results,
|
|
@@ -432,23 +710,29 @@ def process_image_streaming(image, generate_tech_report=False, progress=gr.Progr
|
|
| 432 |
'performance_data': performance_data
|
| 433 |
}
|
| 434 |
|
| 435 |
-
#
|
| 436 |
tech_report = TechnicalReportGenerator().generate_report(tech_report_data)
|
| 437 |
|
| 438 |
-
#
|
| 439 |
report_path = "temp_tech_report.md"
|
| 440 |
with open(report_path, "w") as f:
|
| 441 |
f.write(tech_report)
|
| 442 |
|
| 443 |
-
progress(1.0, desc="
|
|
|
|
|
|
|
| 444 |
return output_image, analysis_report, report_path, confidence_chart, feature_chart, heatmap
|
| 445 |
|
| 446 |
-
progress(1.0, desc="
|
|
|
|
|
|
|
| 447 |
return output_image, analysis_report, None, confidence_chart, feature_chart, heatmap
|
| 448 |
|
| 449 |
except Exception as e:
|
| 450 |
-
error_msg = f"
|
| 451 |
print(error_msg)
|
|
|
|
|
|
|
| 452 |
raise gr.Error(error_msg)
|
| 453 |
|
| 454 |
def display_history():
|
|
@@ -472,12 +756,25 @@ def display_history():
|
|
| 472 |
def get_memory_usage():
|
| 473 |
"""Get current memory usage in MB"""
|
| 474 |
process = psutil.Process()
|
| 475 |
-
|
|
|
|
| 476 |
|
| 477 |
def get_peak_memory_usage():
|
| 478 |
"""Get peak memory usage in MB"""
|
| 479 |
-
|
| 480 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 481 |
|
| 482 |
def get_gpu_utilization():
|
| 483 |
"""Get GPU utilization percentage"""
|
|
@@ -488,128 +785,274 @@ def get_gpu_utilization():
|
|
| 488 |
pass
|
| 489 |
return 0
|
| 490 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 491 |
def toggle_dark_mode():
|
| 492 |
"""Toggle between light and dark themes"""
|
| 493 |
global DARK_MODE
|
| 494 |
DARK_MODE = not DARK_MODE
|
| 495 |
return gr.Theme.darkmode() if DARK_MODE else THEME
|
| 496 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 497 |
# Create Gradio interface
|
| 498 |
with gr.Blocks(theme=THEME) as iface:
|
| 499 |
theme_state = gr.State(THEME)
|
| 500 |
-
|
| 501 |
with gr.Row():
|
| 502 |
-
gr.Markdown("#
|
| 503 |
-
dark_mode_btn = gr.Button("🌓
|
| 504 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 505 |
if USE_DEMO_MODE:
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
""", elem_id="demo-mode-warning")
|
| 516 |
-
|
| 517 |
-
gr.Markdown("
|
| 518 |
-
|
| 519 |
with gr.Tabs() as tabs:
|
| 520 |
-
with gr.TabItem("
|
| 521 |
with gr.Row():
|
| 522 |
with gr.Column(scale=1):
|
| 523 |
-
with gr.Accordion("
|
| 524 |
input_image = gr.Image(
|
| 525 |
-
type="pil",
|
| 526 |
-
label="
|
| 527 |
elem_id="input-image",
|
| 528 |
sources=["upload", "webcam", "clipboard"],
|
| 529 |
tool="editor"
|
| 530 |
)
|
| 531 |
tech_report_checkbox = gr.Checkbox(
|
| 532 |
-
label="
|
| 533 |
value=False,
|
| 534 |
-
info="
|
| 535 |
)
|
| 536 |
analyze_button = gr.Button(
|
| 537 |
-
"
|
| 538 |
variant="primary",
|
| 539 |
elem_id="analyze-btn"
|
| 540 |
)
|
| 541 |
-
|
| 542 |
with gr.Column(scale=2):
|
| 543 |
-
with gr.Accordion("
|
| 544 |
output_image = gr.Image(
|
| 545 |
-
type="pil",
|
| 546 |
-
label="
|
| 547 |
elem_id="output-image"
|
| 548 |
)
|
| 549 |
-
|
| 550 |
-
with gr.Accordion("
|
| 551 |
output_report = gr.HTML(
|
| 552 |
-
label="
|
| 553 |
elem_id="analysis-report"
|
| 554 |
)
|
| 555 |
tech_report_output = gr.File(
|
| 556 |
-
label="
|
| 557 |
elem_id="tech-report"
|
| 558 |
)
|
| 559 |
-
|
| 560 |
with gr.Row():
|
| 561 |
with gr.Column():
|
| 562 |
confidence_plot = gr.Plot(
|
| 563 |
-
label="
|
| 564 |
elem_id="confidence-plot"
|
| 565 |
)
|
| 566 |
-
|
| 567 |
with gr.Column():
|
| 568 |
feature_plot = gr.Plot(
|
| 569 |
-
label="
|
| 570 |
elem_id="feature-plot"
|
| 571 |
)
|
| 572 |
-
|
| 573 |
with gr.Row():
|
| 574 |
heatmap_plot = gr.Plot(
|
| 575 |
-
label="
|
| 576 |
elem_id="heatmap-plot"
|
| 577 |
)
|
| 578 |
-
|
| 579 |
-
with gr.TabItem("
|
| 580 |
with gr.Row():
|
| 581 |
-
history_button = gr.Button("
|
| 582 |
history_output = gr.HTML(elem_id="history-output")
|
| 583 |
-
|
| 584 |
-
with gr.TabItem("
|
| 585 |
gr.Markdown("""
|
| 586 |
-
##
|
| 587 |
-
|
| 588 |
-
1.
|
| 589 |
-
2.
|
| 590 |
-
3.
|
| 591 |
-
4.
|
| 592 |
-
-
|
| 593 |
-
-
|
| 594 |
-
-
|
| 595 |
-
-
|
| 596 |
-
|
| 597 |
-
##
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
- **Ctrl+A**: Trigger analysis
|
| 604 |
-
- **Ctrl+D**: Toggle dark mode
|
| 605 |
-
|
| 606 |
-
## Troubleshooting
|
| 607 |
-
|
| 608 |
-
- If the analysis fails, try uploading a different image format
|
| 609 |
-
- Ensure the image is a valid radar scan
|
| 610 |
-
- For technical issues, check the console logs
|
| 611 |
""")
|
| 612 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 613 |
# Set up event handlers
|
| 614 |
dark_mode_btn.click(
|
| 615 |
fn=toggle_dark_mode,
|
|
@@ -617,21 +1060,21 @@ with gr.Blocks(theme=THEME) as iface:
|
|
| 617 |
outputs=[iface],
|
| 618 |
api_name="toggle_theme"
|
| 619 |
)
|
| 620 |
-
|
| 621 |
analyze_button.click(
|
| 622 |
fn=process_image_streaming,
|
| 623 |
inputs=[input_image, tech_report_checkbox],
|
| 624 |
outputs=[output_image, output_report, tech_report_output, confidence_plot, feature_plot, heatmap_plot],
|
| 625 |
api_name="analyze"
|
| 626 |
)
|
| 627 |
-
|
| 628 |
history_button.click(
|
| 629 |
fn=display_history,
|
| 630 |
inputs=[],
|
| 631 |
outputs=[history_output],
|
| 632 |
api_name="history"
|
| 633 |
)
|
| 634 |
-
|
| 635 |
# Add keyboard shortcuts
|
| 636 |
iface.load(lambda: None, None, None, _js="""
|
| 637 |
() => {
|
|
@@ -640,11 +1083,11 @@ with gr.Blocks(theme=THEME) as iface:
|
|
| 640 |
document.getElementById('analyze-btn').click();
|
| 641 |
}
|
| 642 |
if (e.key === 'd' && e.ctrlKey) {
|
| 643 |
-
document.querySelector('button:contains("
|
| 644 |
}
|
| 645 |
});
|
| 646 |
}
|
| 647 |
""")
|
| 648 |
|
| 649 |
# Launch the interface
|
| 650 |
-
|
|
|
|
| 12 |
import plotly.graph_objects as go
|
| 13 |
import pandas as pd
|
| 14 |
from functools import partial
|
| 15 |
+
import logging
|
| 16 |
|
| 17 |
from model import RadarDetectionModel
|
| 18 |
from feature_extraction import (calculate_amplitude, classify_amplitude,
|
|
|
|
| 23 |
from report_generation import generate_report, render_report
|
| 24 |
from utils import plot_detection
|
| 25 |
from database import save_report, get_report_history
|
| 26 |
+
from config import MODEL_NAME
|
| 27 |
+
|
| 28 |
+
# Configure logging
|
| 29 |
+
logging.basicConfig(level=logging.INFO,
|
| 30 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
| 31 |
+
logger = logging.getLogger(__name__)
|
| 32 |
|
| 33 |
# Set theme and styling
|
| 34 |
THEME = gr.themes.Soft(
|
|
|
|
| 42 |
# Create a simple dark mode flag instead of custom theme
|
| 43 |
DARK_MODE = False
|
| 44 |
|
| 45 |
+
# Global variables
|
| 46 |
+
model = None
|
| 47 |
+
USE_DEMO_MODE = False
|
| 48 |
+
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HF_TOCKEN")
|
| 49 |
+
|
| 50 |
+
# 添加一个标志,表示是否已经尝试过初始化模型
|
| 51 |
+
MODEL_INIT_ATTEMPTED = False
|
| 52 |
+
|
| 53 |
class TechnicalReportGenerator:
|
| 54 |
def __init__(self):
|
| 55 |
self.timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 56 |
+
|
| 57 |
def generate_model_analysis(self, model_outputs):
|
| 58 |
"""Generate model-specific analysis section"""
|
| 59 |
model_section = "## Model Analysis\n\n"
|
| 60 |
+
|
| 61 |
# Image encoder analysis
|
| 62 |
model_section += "### Image Encoder (SigLIP-So400m) Analysis\n"
|
| 63 |
model_section += "- Feature extraction quality: {:.2f}%\n".format(model_outputs.get('feature_quality', 0) * 100)
|
| 64 |
model_section += "- Image encoding latency: {:.2f}ms\n".format(model_outputs.get('encoding_latency', 0))
|
| 65 |
model_section += "- Feature map dimensions: {}\n\n".format(model_outputs.get('feature_dimensions', 'N/A'))
|
| 66 |
+
|
| 67 |
# Text decoder analysis
|
| 68 |
model_section += "### Text Decoder (Gemma-2B) Analysis\n"
|
| 69 |
model_section += "- Text generation confidence: {:.2f}%\n".format(model_outputs.get('text_confidence', 0) * 100)
|
| 70 |
model_section += "- Decoding latency: {:.2f}ms\n".format(model_outputs.get('decoding_latency', 0))
|
| 71 |
model_section += "- Token processing rate: {:.2f} tokens/sec\n\n".format(model_outputs.get('token_rate', 0))
|
| 72 |
+
|
| 73 |
return model_section
|
| 74 |
|
| 75 |
def generate_detection_analysis(self, detection_results):
|
| 76 |
"""Generate detailed detection analysis section"""
|
| 77 |
detection_section = "## Detection Analysis\n\n"
|
| 78 |
+
|
| 79 |
# Detection metrics
|
| 80 |
detection_section += "### Object Detection Metrics\n"
|
| 81 |
detection_section += "| Metric | Value |\n"
|
|
|
|
| 87 |
detection_section += "| Processing Time | {:.2f}ms |\n\n".format(
|
| 88 |
detection_results.get('processing_time', 0)
|
| 89 |
)
|
| 90 |
+
|
| 91 |
# Detailed detection results
|
| 92 |
detection_section += "### Detection Details\n"
|
| 93 |
detection_section += "| Object | Confidence | Bounding Box |\n"
|
| 94 |
detection_section += "|--------|------------|---------------|\n"
|
| 95 |
+
|
| 96 |
boxes = detection_results.get('boxes', [])
|
| 97 |
scores = detection_results.get('scores', [])
|
| 98 |
labels = detection_results.get('labels', [])
|
| 99 |
+
|
| 100 |
for box, score, label in zip(boxes, scores, labels):
|
| 101 |
detection_section += "| {} | {:.2f}% | {} |\n".format(
|
| 102 |
label,
|
| 103 |
score * 100,
|
| 104 |
[round(coord, 2) for coord in box]
|
| 105 |
)
|
| 106 |
+
|
| 107 |
return detection_section
|
| 108 |
|
| 109 |
def generate_multimodal_analysis(self, mm_results):
|
| 110 |
"""Generate multimodal analysis section"""
|
| 111 |
mm_section = "## Multimodal Analysis\n\n"
|
| 112 |
+
|
| 113 |
# Feature correlation analysis
|
| 114 |
mm_section += "### Feature Correlation Analysis\n"
|
| 115 |
mm_section += "- Text-Image Alignment Score: {:.2f}%\n".format(
|
|
|
|
| 121 |
mm_section += "- Feature Space Correlation: {:.2f}\n\n".format(
|
| 122 |
mm_results.get('feature_correlation', 0)
|
| 123 |
)
|
| 124 |
+
|
| 125 |
return mm_section
|
| 126 |
|
| 127 |
def generate_performance_metrics(self, perf_data):
|
| 128 |
"""Generate performance metrics section"""
|
| 129 |
perf_section = "## Performance Metrics\n\n"
|
| 130 |
+
|
| 131 |
# System metrics
|
| 132 |
perf_section += "### System Performance\n"
|
| 133 |
perf_section += "- Total Processing Time: {:.2f}ms\n".format(perf_data.get('total_time', 0))
|
| 134 |
perf_section += "- Peak Memory Usage: {:.2f}MB\n".format(perf_data.get('peak_memory', 0))
|
| 135 |
perf_section += "- GPU Utilization: {:.2f}%\n\n".format(perf_data.get('gpu_util', 0))
|
| 136 |
+
|
| 137 |
# Pipeline metrics
|
| 138 |
perf_section += "### Pipeline Statistics\n"
|
| 139 |
perf_section += "| Stage | Time (ms) | Memory (MB) |\n"
|
|
|
|
| 145 |
stats.get('time', 0),
|
| 146 |
stats.get('memory', 0)
|
| 147 |
)
|
| 148 |
+
|
| 149 |
return perf_section
|
| 150 |
|
| 151 |
def generate_report(self, results):
|
| 152 |
"""Generate comprehensive technical report"""
|
| 153 |
report = f"# Technical Analysis Report\nGenerated at: {self.timestamp}\n\n"
|
| 154 |
+
|
| 155 |
# Add model analysis
|
| 156 |
report += self.generate_model_analysis(results.get('model_outputs', {}))
|
| 157 |
+
|
| 158 |
# Add detection analysis
|
| 159 |
report += self.generate_detection_analysis(results.get('detection_results', {}))
|
| 160 |
+
|
| 161 |
# Add multimodal analysis
|
| 162 |
report += self.generate_multimodal_analysis(results.get('multimodal_results', {}))
|
| 163 |
+
|
| 164 |
# Add performance metrics
|
| 165 |
report += self.generate_performance_metrics(results.get('performance_data', {}))
|
| 166 |
+
|
| 167 |
return report
|
| 168 |
|
| 169 |
+
def check_available_memory():
|
| 170 |
+
"""Check available system memory in MB"""
|
| 171 |
+
try:
|
| 172 |
+
import psutil
|
| 173 |
+
vm = psutil.virtual_memory()
|
| 174 |
+
available_mb = vm.available / (1024 * 1024)
|
| 175 |
+
total_mb = vm.total / (1024 * 1024)
|
| 176 |
+
print(f"Available memory: {available_mb:.2f}MB out of {total_mb:.2f}MB total")
|
| 177 |
+
return available_mb
|
| 178 |
+
except Exception as e:
|
| 179 |
+
print(f"Error checking memory: {str(e)}")
|
| 180 |
+
return 0
|
| 181 |
+
|
| 182 |
+
def monitor_memory_during_loading(model_name, use_auth_token=None):
|
| 183 |
+
"""Monitor memory usage during model loading and abort if it gets too high"""
|
| 184 |
+
global USE_DEMO_MODE
|
| 185 |
+
|
| 186 |
+
try:
|
| 187 |
+
# Initial memory check
|
| 188 |
+
initial_memory = get_memory_usage()
|
| 189 |
+
print(f"Initial memory usage: {initial_memory:.2f}MB")
|
| 190 |
+
|
| 191 |
+
# Start loading processor
|
| 192 |
+
print(f"Loading processor from {model_name}")
|
| 193 |
+
if use_auth_token:
|
| 194 |
+
processor = AutoProcessor.from_pretrained(model_name, use_auth_token=use_auth_token)
|
| 195 |
+
else:
|
| 196 |
+
processor = AutoProcessor.from_pretrained(model_name)
|
| 197 |
+
|
| 198 |
+
# Check memory after processor loading
|
| 199 |
+
after_processor_memory = get_memory_usage()
|
| 200 |
+
print(f"Memory after processor loading: {after_processor_memory:.2f}MB (Δ: {after_processor_memory - initial_memory:.2f}MB)")
|
| 201 |
+
|
| 202 |
+
# Check if memory is getting too high
|
| 203 |
+
available_memory = check_available_memory()
|
| 204 |
+
if available_memory < 4000: # Less than 4GB available
|
| 205 |
+
print(f"Warning: Only {available_memory:.2f}MB memory available after loading processor")
|
| 206 |
+
print("Aborting model loading to avoid out-of-memory error")
|
| 207 |
+
USE_DEMO_MODE = True
|
| 208 |
+
return None, None
|
| 209 |
+
|
| 210 |
+
# Start loading model with 8-bit quantization
|
| 211 |
+
print(f"Loading model from {model_name} with 8-bit quantization")
|
| 212 |
+
if use_auth_token:
|
| 213 |
+
model = AutoModelForVision2Seq.from_pretrained(
|
| 214 |
+
model_name,
|
| 215 |
+
use_auth_token=use_auth_token,
|
| 216 |
+
load_in_8bit=True,
|
| 217 |
+
device_map="auto"
|
| 218 |
+
)
|
| 219 |
+
else:
|
| 220 |
+
model = AutoModelForVision2Seq.from_pretrained(
|
| 221 |
+
model_name,
|
| 222 |
+
load_in_8bit=True,
|
| 223 |
+
device_map="auto"
|
| 224 |
+
)
|
| 225 |
|
| 226 |
+
# Check memory after model loading
|
| 227 |
+
after_model_memory = get_memory_usage()
|
| 228 |
+
print(f"Memory after model loading: {after_model_memory:.2f}MB (Δ: {after_model_memory - after_processor_memory:.2f}MB)")
|
| 229 |
+
|
| 230 |
+
# Set model to evaluation mode
|
| 231 |
+
model.eval()
|
| 232 |
+
|
| 233 |
+
return processor, model
|
| 234 |
+
except Exception as e:
|
| 235 |
+
print(f"Error during monitored model loading: {str(e)}")
|
| 236 |
+
USE_DEMO_MODE = True
|
| 237 |
+
return None, None
|
| 238 |
+
|
| 239 |
+
def is_running_in_space():
|
| 240 |
+
"""Check if we're running in a Hugging Face Space environment"""
|
| 241 |
+
return os.environ.get("SPACE_ID") is not None
|
| 242 |
+
|
| 243 |
+
def is_container_environment():
|
| 244 |
+
"""Check if we're running in a container environment"""
|
| 245 |
+
return os.path.exists("/.dockerenv") or os.path.exists("/run/.containerenv")
|
| 246 |
+
|
| 247 |
+
def is_cpu_only():
|
| 248 |
+
"""Check if we're running in a CPU-only environment"""
|
| 249 |
+
return not torch.cuda.is_available()
|
| 250 |
+
|
| 251 |
+
def is_low_memory_environment():
|
| 252 |
+
"""Check if we're running in a low-memory environment"""
|
| 253 |
+
available_memory = check_available_memory()
|
| 254 |
+
return available_memory < 8000 # Less than 8GB available
|
| 255 |
+
|
| 256 |
+
def is_development_environment():
|
| 257 |
+
"""Check if we're running in a development environment"""
|
| 258 |
+
return not (is_running_in_space() or is_container_environment())
|
| 259 |
+
|
| 260 |
+
def is_debug_mode():
|
| 261 |
+
"""Check if we're running in debug mode"""
|
| 262 |
+
return os.environ.get("DEBUG", "").lower() in ("1", "true", "yes")
|
| 263 |
+
|
| 264 |
+
def is_test_mode():
|
| 265 |
+
"""Check if we're running in test mode"""
|
| 266 |
+
return os.environ.get("TEST", "").lower() in ("1", "true", "yes")
|
| 267 |
+
|
| 268 |
+
def is_low_memory_container():
|
| 269 |
+
"""Check if we're running in a container with memory limits"""
|
| 270 |
+
if not is_container_environment():
|
| 271 |
+
return False
|
| 272 |
+
|
| 273 |
+
# Check if cgroup memory limit is set
|
| 274 |
+
try:
|
| 275 |
+
with open('/sys/fs/cgroup/memory/memory.limit_in_bytes', 'r') as f:
|
| 276 |
+
limit = int(f.read().strip())
|
| 277 |
+
# Convert to MB
|
| 278 |
+
limit_mb = limit / (1024 * 1024)
|
| 279 |
+
print(f"Container memory limit: {limit_mb:.2f}MB")
|
| 280 |
+
return limit_mb < 20000 # Less than 20GB
|
| 281 |
+
except:
|
| 282 |
+
# If we can't read the limit, assume it's a low-memory container
|
| 283 |
+
return True
|
| 284 |
+
|
| 285 |
+
def is_space_hardware_type(hardware_type):
|
| 286 |
+
"""Check if we're running in a Hugging Face Space with a specific hardware type"""
|
| 287 |
+
if not is_running_in_space():
|
| 288 |
+
return False
|
| 289 |
+
|
| 290 |
+
# Check if SPACE_HARDWARE environment variable matches the specified type
|
| 291 |
+
return os.environ.get("SPACE_HARDWARE", "").lower() == hardware_type.lower()
|
| 292 |
+
|
| 293 |
+
def get_space_hardware_tier():
|
| 294 |
+
"""Get the hardware tier of the Hugging Face Space"""
|
| 295 |
+
if not is_running_in_space():
|
| 296 |
+
return "Not a Space"
|
| 297 |
+
|
| 298 |
+
hardware = os.environ.get("SPACE_HARDWARE", "unknown")
|
| 299 |
+
|
| 300 |
+
# Determine the tier based on hardware type
|
| 301 |
+
if hardware.lower() == "cpu":
|
| 302 |
+
return "Basic (CPU)"
|
| 303 |
+
elif hardware.lower() == "t4-small":
|
| 304 |
+
return "Basic (GPU)"
|
| 305 |
+
elif hardware.lower() == "t4-medium":
|
| 306 |
+
return "Standard"
|
| 307 |
+
elif hardware.lower() == "a10g-small":
|
| 308 |
+
return "Pro"
|
| 309 |
+
elif hardware.lower() == "a10g-large":
|
| 310 |
+
return "Pro+"
|
| 311 |
+
elif hardware.lower() == "a100-large":
|
| 312 |
+
return "Enterprise"
|
| 313 |
+
else:
|
| 314 |
+
return f"Unknown ({hardware})"
|
| 315 |
+
|
| 316 |
+
def get_space_hardware_memory():
|
| 317 |
+
"""Get the memory size of the Hugging Face Space hardware in GB"""
|
| 318 |
+
if not is_running_in_space():
|
| 319 |
+
return 0
|
| 320 |
+
|
| 321 |
+
hardware = os.environ.get("SPACE_HARDWARE", "unknown").lower()
|
| 322 |
+
|
| 323 |
+
# Determine the memory size based on hardware type
|
| 324 |
+
if hardware == "cpu":
|
| 325 |
+
return 16 # 16GB for CPU
|
| 326 |
+
elif hardware == "t4-small":
|
| 327 |
+
return 16 # 16GB for T4 Small
|
| 328 |
+
elif hardware == "t4-medium":
|
| 329 |
+
return 16 # 16GB for T4 Medium
|
| 330 |
+
elif hardware == "a10g-small":
|
| 331 |
+
return 24 # 24GB for A10G Small
|
| 332 |
+
elif hardware == "a10g-large":
|
| 333 |
+
return 40 # 40GB for A10G Large
|
| 334 |
+
elif hardware == "a100-large":
|
| 335 |
+
return 80 # 80GB for A100 Large
|
| 336 |
+
else:
|
| 337 |
+
return 16 # Default to 16GB
|
| 338 |
+
|
| 339 |
+
def get_total_system_memory():
|
| 340 |
+
"""Get total system memory in MB"""
|
| 341 |
try:
|
| 342 |
+
import psutil
|
| 343 |
+
total_bytes = psutil.virtual_memory().total
|
| 344 |
+
total_mb = total_bytes / (1024 * 1024)
|
| 345 |
+
return total_mb
|
| 346 |
except Exception as e:
|
| 347 |
+
print(f"Error getting total system memory: {str(e)}")
|
| 348 |
+
return 0
|
| 349 |
+
|
| 350 |
+
def estimate_model_memory_requirements():
|
| 351 |
+
"""Estimate the memory requirements for the model"""
|
| 352 |
+
# This is a placeholder implementation. You might want to implement a more accurate estimation based on your model's architecture and typical input sizes.
|
| 353 |
+
try:
|
| 354 |
+
HF_TOCKEN = os.getenv("HF_TOCKEN")
|
| 355 |
+
|
| 356 |
+
# Print startup message
|
| 357 |
+
print("===== Application Startup at", datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "=====")
|
| 358 |
+
|
| 359 |
+
# Get system memory information
|
| 360 |
+
total_memory = get_total_system_memory()
|
| 361 |
+
required_memory = estimate_model_memory_requirements()
|
| 362 |
+
recommended_tier = get_recommended_space_tier()
|
| 363 |
+
print(f"NOTICE: Total system memory: {total_memory:.2f}MB")
|
| 364 |
+
print(f"NOTICE: Estimated model memory requirement: {required_memory:.2f}MB")
|
| 365 |
+
print(f"NOTICE: Recommended Space tier: {recommended_tier}")
|
| 366 |
+
|
| 367 |
+
if is_test_mode():
|
| 368 |
+
print("NOTICE: Running in TEST mode")
|
| 369 |
+
print("NOTICE: Using mock data and responses")
|
| 370 |
+
USE_DEMO_MODE = True
|
| 371 |
+
|
| 372 |
+
if is_debug_mode():
|
| 373 |
+
print("NOTICE: Running in DEBUG mode")
|
| 374 |
+
print("NOTICE: Additional logging and diagnostics will be enabled")
|
| 375 |
+
|
| 376 |
+
if is_development_environment():
|
| 377 |
+
print("NOTICE: Running in development environment")
|
| 378 |
+
print("NOTICE: Full model capabilities may be available depending on system resources")
|
| 379 |
+
|
| 380 |
+
if is_running_in_space():
|
| 381 |
+
print("NOTICE: Running in Hugging Face Space environment")
|
| 382 |
+
|
| 383 |
+
# Check Space hardware type
|
| 384 |
+
hardware_type = get_space_hardware_type()
|
| 385 |
+
hardware_tier = get_space_hardware_tier()
|
| 386 |
+
hardware_memory = get_space_hardware_memory()
|
| 387 |
+
print(f"NOTICE: Space hardware type: {hardware_type} (Tier: {hardware_tier}, Memory: {hardware_memory}GB)")
|
| 388 |
+
|
| 389 |
+
if has_enough_memory_for_model():
|
| 390 |
+
print("NOTICE: This Space has enough memory for the model, but we're still forcing demo mode for stability")
|
| 391 |
+
else:
|
| 392 |
+
print(f"NOTICE: This Space does NOT have enough memory for the model (Need: {required_memory:.2f}MB, Have: {hardware_memory*1024:.2f}MB)")
|
| 393 |
+
print(f"NOTICE: Recommended Space tier: {recommended_tier}")
|
| 394 |
+
|
| 395 |
+
print("NOTICE: FORCING DEMO MODE to avoid 'Memory limit exceeded (16Gi)' error")
|
| 396 |
+
print("NOTICE: The PaliGemma model is too large for the 16GB memory limit in Spaces")
|
| 397 |
+
print("NOTICE: To use the full model, please run this application locally")
|
| 398 |
+
USE_DEMO_MODE = True
|
| 399 |
+
elif is_container_environment():
|
| 400 |
+
print("NOTICE: Running in a container environment")
|
| 401 |
+
print("NOTICE: Memory limits may be enforced by the container runtime")
|
| 402 |
+
|
| 403 |
+
if is_cpu_only():
|
| 404 |
+
print("NOTICE: Running in CPU-only environment")
|
| 405 |
+
print("NOTICE: Model loading and inference will be slower")
|
| 406 |
+
|
| 407 |
+
# Check available memory
|
| 408 |
+
available_memory = check_available_memory()
|
| 409 |
+
print(f"NOTICE: Available memory: {available_memory:.2f}MB")
|
| 410 |
+
|
| 411 |
+
if is_low_memory_environment() and not USE_DEMO_MODE:
|
| 412 |
+
print("NOTICE: Running in a low-memory environment")
|
| 413 |
+
print("NOTICE: Enabling DEMO MODE to avoid memory issues")
|
| 414 |
USE_DEMO_MODE = True
|
| 415 |
else:
|
| 416 |
+
# Check available memory before loading
|
| 417 |
+
available_memory = check_available_memory()
|
| 418 |
+
if available_memory < 8000: # If less than 8GB available
|
| 419 |
+
print(f"Warning: Only {available_memory:.2f}MB memory available, which may not be enough for the full model")
|
| 420 |
+
return required_memory
|
| 421 |
+
except Exception as e:
|
| 422 |
+
print(f"Warning: Model initialization failed: {str(e)}")
|
| 423 |
+
print("Falling back to demo mode.")
|
| 424 |
+
USE_DEMO_MODE = True
|
| 425 |
+
return 0
|
| 426 |
|
| 427 |
def initialize_model():
|
| 428 |
+
"""
|
| 429 |
+
仅在需要时初始化模型,不会在应用启动时自动加载
|
| 430 |
+
"""
|
| 431 |
+
global model, USE_DEMO_MODE, MODEL_INIT_ATTEMPTED
|
| 432 |
|
| 433 |
+
# 如果已经初始化过模型,直接返回
|
| 434 |
+
if model is not None:
|
| 435 |
+
return model
|
| 436 |
+
|
| 437 |
+
# 如果已经尝试过初始化并失败,使用演示模式
|
| 438 |
+
if MODEL_INIT_ATTEMPTED and model is None:
|
| 439 |
+
logger.info("已尝试过初始化模型但失败,使用演示模式")
|
| 440 |
+
USE_DEMO_MODE = True
|
| 441 |
+
return None
|
| 442 |
+
|
| 443 |
+
# 标记为已尝试初始化
|
| 444 |
+
MODEL_INIT_ATTEMPTED = True
|
| 445 |
+
|
| 446 |
+
# 检查是否在Hugging Face Space环境中运行
|
| 447 |
+
if is_running_in_space():
|
| 448 |
+
logger.info("在Hugging Face Space环境中运行")
|
| 449 |
+
|
| 450 |
+
# 检查可用内存
|
| 451 |
+
available_memory = check_available_memory()
|
| 452 |
+
logger.info(f"可用内存: {available_memory:.2f}MB")
|
| 453 |
+
|
| 454 |
+
if available_memory < 8000: # 如果可用内存少于8GB
|
| 455 |
+
logger.warning(f"只有{available_memory:.2f}MB可用内存,可能不足以加载模型")
|
| 456 |
+
logger.info("使用演示模式以避免内存问题")
|
| 457 |
USE_DEMO_MODE = True
|
| 458 |
+
return None
|
| 459 |
|
| 460 |
+
if USE_DEMO_MODE:
|
| 461 |
+
logger.info("使用演示模式 - 不会加载模型")
|
| 462 |
+
return None # 在演示模式下使用模拟数据
|
| 463 |
+
|
| 464 |
+
try:
|
| 465 |
+
# 从环境变量获取token
|
| 466 |
+
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HF_TOCKEN")
|
| 467 |
+
|
| 468 |
+
logger.info(f"尝试加载模型 {MODEL_NAME}")
|
| 469 |
+
model = RadarDetectionModel(model_name=MODEL_NAME, use_auth_token=hf_token)
|
| 470 |
+
logger.info(f"成功加载模型 {MODEL_NAME}")
|
| 471 |
+
return model
|
| 472 |
+
except Exception as e:
|
| 473 |
+
logger.error(f"模型初始化错误: {str(e)}")
|
| 474 |
+
logger.info("由于模型加载错误,切换到演示模式")
|
| 475 |
+
USE_DEMO_MODE = True
|
| 476 |
+
return None
|
| 477 |
|
| 478 |
def create_confidence_chart(scores, labels):
|
| 479 |
"""Create a bar chart for confidence scores"""
|
| 480 |
if not scores or not labels:
|
| 481 |
return None
|
| 482 |
+
|
| 483 |
df = pd.DataFrame({
|
| 484 |
'Label': labels,
|
| 485 |
'Confidence': [score * 100 for score in scores]
|
| 486 |
})
|
| 487 |
+
|
| 488 |
fig = px.bar(
|
| 489 |
+
df,
|
| 490 |
+
x='Label',
|
| 491 |
y='Confidence',
|
| 492 |
title='Detection Confidence Scores',
|
| 493 |
labels={'Confidence': 'Confidence (%)'},
|
| 494 |
color='Confidence',
|
| 495 |
color_continuous_scale='viridis'
|
| 496 |
)
|
| 497 |
+
|
| 498 |
fig.update_layout(
|
| 499 |
xaxis_title='Detected Object',
|
| 500 |
yaxis_title='Confidence (%)',
|
| 501 |
yaxis_range=[0, 100],
|
| 502 |
template='plotly_white'
|
| 503 |
)
|
| 504 |
+
|
| 505 |
return fig
|
| 506 |
|
| 507 |
def create_feature_radar_chart(features):
|
| 508 |
"""Create a radar chart for feature analysis"""
|
| 509 |
categories = list(features.keys())
|
| 510 |
values = []
|
| 511 |
+
|
| 512 |
# Convert text classifications to numeric values (1-5 scale)
|
| 513 |
for feature in features.values():
|
| 514 |
if "High" in feature:
|
|
|
|
| 523 |
values.append(1)
|
| 524 |
else:
|
| 525 |
values.append(0)
|
| 526 |
+
|
| 527 |
fig = go.Figure()
|
| 528 |
+
|
| 529 |
fig.add_trace(go.Scatterpolar(
|
| 530 |
r=values,
|
| 531 |
theta=categories,
|
| 532 |
fill='toself',
|
| 533 |
name='Feature Analysis'
|
| 534 |
))
|
| 535 |
+
|
| 536 |
fig.update_layout(
|
| 537 |
polar=dict(
|
| 538 |
radialaxis=dict(
|
|
|
|
| 543 |
title='Feature Analysis Radar Chart',
|
| 544 |
template='plotly_white'
|
| 545 |
)
|
| 546 |
+
|
| 547 |
return fig
|
| 548 |
|
| 549 |
def create_heatmap(image_array):
|
| 550 |
"""Create a heatmap visualization of the image intensity"""
|
| 551 |
if image_array is None:
|
| 552 |
return None
|
| 553 |
+
|
| 554 |
# Convert to grayscale if needed
|
| 555 |
if len(image_array.shape) == 3 and image_array.shape[2] == 3:
|
| 556 |
gray_img = np.mean(image_array, axis=2)
|
| 557 |
else:
|
| 558 |
gray_img = image_array
|
| 559 |
+
|
| 560 |
fig = px.imshow(
|
| 561 |
gray_img,
|
| 562 |
color_continuous_scale='inferno',
|
| 563 |
title='Signal Intensity Heatmap'
|
| 564 |
)
|
| 565 |
+
|
| 566 |
fig.update_layout(
|
| 567 |
xaxis_title='X Position',
|
| 568 |
yaxis_title='Y Position',
|
| 569 |
template='plotly_white'
|
| 570 |
)
|
| 571 |
+
|
| 572 |
return fig
|
| 573 |
|
| 574 |
+
def cleanup_memory():
|
| 575 |
+
"""Attempt to clean up memory by forcing garbage collection"""
|
| 576 |
+
try:
|
| 577 |
+
import gc
|
| 578 |
+
gc.collect()
|
| 579 |
+
if torch.cuda.is_available():
|
| 580 |
+
torch.cuda.empty_cache()
|
| 581 |
+
print("Memory cleanup performed")
|
| 582 |
+
except Exception as e:
|
| 583 |
+
print(f"Error during memory cleanup: {str(e)}")
|
| 584 |
+
|
| 585 |
def process_image_streaming(image, generate_tech_report=False, progress=gr.Progress()):
|
| 586 |
+
"""处理图像并提供流式进度更新"""
|
| 587 |
if image is None:
|
| 588 |
+
raise gr.Error("请上传一张图像。")
|
| 589 |
|
| 590 |
+
# 仅在需要时初始化模型
|
| 591 |
+
progress(0.1, desc="初始化模型...")
|
| 592 |
+
log_memory_usage("在process_image中初始化模型之前")
|
| 593 |
global model, USE_DEMO_MODE
|
| 594 |
|
| 595 |
if not USE_DEMO_MODE:
|
| 596 |
+
model = initialize_model()
|
| 597 |
+
if model is None:
|
| 598 |
+
progress(0.15, desc="切换到演示模式...")
|
| 599 |
USE_DEMO_MODE = True
|
| 600 |
|
| 601 |
try:
|
| 602 |
+
# 如果需要,将图像转换为PIL Image
|
| 603 |
if isinstance(image, np.ndarray):
|
| 604 |
image = Image.fromarray(image)
|
| 605 |
|
| 606 |
+
# 运行检测
|
| 607 |
+
progress(0.2, desc="运行检测...")
|
| 608 |
+
log_memory_usage("检测之前")
|
| 609 |
|
| 610 |
if USE_DEMO_MODE:
|
| 611 |
+
# 在演示模式下使用模拟检测结果
|
| 612 |
detection_result = {
|
| 613 |
'boxes': [[100, 100, 200, 200], [300, 300, 400, 400]],
|
| 614 |
'scores': [0.92, 0.85],
|
| 615 |
+
'labels': ['裂缝', '腐蚀'],
|
| 616 |
'image': image
|
| 617 |
}
|
| 618 |
else:
|
| 619 |
+
try:
|
| 620 |
+
detection_result = model.detect(image)
|
| 621 |
+
log_memory_usage("检测之后")
|
| 622 |
+
except Exception as e:
|
| 623 |
+
logger.error(f"检测过程中出错: {str(e)}")
|
| 624 |
+
# 如果检测失败,切换到演示模式
|
| 625 |
+
USE_DEMO_MODE = True
|
| 626 |
+
detection_result = {
|
| 627 |
+
'boxes': [[100, 100, 200, 200], [300, 300, 400, 400]],
|
| 628 |
+
'scores': [0.92, 0.85],
|
| 629 |
+
'labels': ['错误', '备用'],
|
| 630 |
+
'image': image
|
| 631 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 632 |
|
| 633 |
+
# 提取特征
|
| 634 |
+
progress(0.3, desc="提取特征...")
|
| 635 |
+
features = extract_features(image, detection_result)
|
|
|
|
|
|
|
|
|
|
| 636 |
|
| 637 |
+
# 创建可视化图表
|
| 638 |
+
progress(0.5, desc="创建可视化...")
|
| 639 |
confidence_chart = create_confidence_chart(
|
| 640 |
+
detection_result.get('scores', []),
|
| 641 |
detection_result.get('labels', [])
|
| 642 |
)
|
| 643 |
|
| 644 |
feature_chart = create_feature_radar_chart(features)
|
| 645 |
+
heatmap = create_heatmap(np.array(image))
|
| 646 |
|
| 647 |
+
# 开始性能跟踪
|
| 648 |
+
progress(0.6, desc="分析性能...")
|
| 649 |
start_time = time.time()
|
| 650 |
performance_data = {
|
| 651 |
'pipeline_stats': {},
|
|
|
|
| 653 |
'gpu_util': 0
|
| 654 |
}
|
| 655 |
|
| 656 |
+
# 处理图像并获取结果
|
| 657 |
stage_start = time.time()
|
| 658 |
detection_results = detection_result
|
| 659 |
detection_results['processing_time'] = (time.time() - stage_start) * 1000
|
|
|
|
| 662 |
'memory': get_memory_usage()
|
| 663 |
}
|
| 664 |
|
| 665 |
+
# 提取特征并分析
|
| 666 |
stage_start = time.time()
|
| 667 |
model_outputs = {
|
| 668 |
'feature_quality': 0.85,
|
|
|
|
| 677 |
'memory': get_memory_usage()
|
| 678 |
}
|
| 679 |
|
| 680 |
+
# 执行多模态分析
|
| 681 |
stage_start = time.time()
|
| 682 |
multimodal_results = {
|
| 683 |
'alignment_score': 0.78,
|
|
|
|
| 689 |
'memory': get_memory_usage()
|
| 690 |
}
|
| 691 |
|
| 692 |
+
# 更新性能数据
|
| 693 |
performance_data['total_time'] = (time.time() - start_time) * 1000
|
| 694 |
performance_data['peak_memory'] = get_peak_memory_usage()
|
| 695 |
performance_data['gpu_util'] = get_gpu_utilization()
|
| 696 |
|
| 697 |
+
# 生成分析报告
|
| 698 |
+
progress(0.8, desc="生成报告...")
|
| 699 |
analysis_report = generate_report(detection_result, features)
|
| 700 |
|
| 701 |
+
# 准备输出
|
| 702 |
output_image = plot_detection(image, detection_result)
|
| 703 |
|
| 704 |
if generate_tech_report:
|
| 705 |
+
# 准备技术报告的数据
|
| 706 |
tech_report_data = {
|
| 707 |
'model_outputs': model_outputs,
|
| 708 |
'detection_results': detection_results,
|
|
|
|
| 710 |
'performance_data': performance_data
|
| 711 |
}
|
| 712 |
|
| 713 |
+
# 生成技术报告
|
| 714 |
tech_report = TechnicalReportGenerator().generate_report(tech_report_data)
|
| 715 |
|
| 716 |
+
# 将技术报告保存到临时文件
|
| 717 |
report_path = "temp_tech_report.md"
|
| 718 |
with open(report_path, "w") as f:
|
| 719 |
f.write(tech_report)
|
| 720 |
|
| 721 |
+
progress(1.0, desc="分析完成!")
|
| 722 |
+
# 处理完成后清理内存
|
| 723 |
+
cleanup_memory()
|
| 724 |
return output_image, analysis_report, report_path, confidence_chart, feature_chart, heatmap
|
| 725 |
|
| 726 |
+
progress(1.0, desc="分析完成!")
|
| 727 |
+
# 处理完成后清理内存
|
| 728 |
+
cleanup_memory()
|
| 729 |
return output_image, analysis_report, None, confidence_chart, feature_chart, heatmap
|
| 730 |
|
| 731 |
except Exception as e:
|
| 732 |
+
error_msg = f"处理图像时出错: {str(e)}"
|
| 733 |
print(error_msg)
|
| 734 |
+
# 出错后清理内存
|
| 735 |
+
cleanup_memory()
|
| 736 |
raise gr.Error(error_msg)
|
| 737 |
|
| 738 |
def display_history():
|
|
|
|
| 756 |
def get_memory_usage():
|
| 757 |
"""Get current memory usage in MB"""
|
| 758 |
process = psutil.Process()
|
| 759 |
+
memory_info = process.memory_info()
|
| 760 |
+
return memory_info.rss / 1024 / 1024
|
| 761 |
|
| 762 |
def get_peak_memory_usage():
|
| 763 |
"""Get peak memory usage in MB"""
|
| 764 |
+
try:
|
| 765 |
+
process = psutil.Process()
|
| 766 |
+
memory_info = process.memory_info()
|
| 767 |
+
if hasattr(memory_info, 'peak_wset'):
|
| 768 |
+
return memory_info.peak_wset / 1024 / 1024
|
| 769 |
+
else:
|
| 770 |
+
# On Linux, we can use /proc/self/status to get peak memory
|
| 771 |
+
with open('/proc/self/status') as f:
|
| 772 |
+
for line in f:
|
| 773 |
+
if line.startswith('VmHWM:'):
|
| 774 |
+
return float(line.split()[1]) / 1024 # Convert KB to MB
|
| 775 |
+
except:
|
| 776 |
+
pass
|
| 777 |
+
return 0
|
| 778 |
|
| 779 |
def get_gpu_utilization():
|
| 780 |
"""Get GPU utilization percentage"""
|
|
|
|
| 785 |
pass
|
| 786 |
return 0
|
| 787 |
|
| 788 |
+
def log_memory_usage(stage=""):
|
| 789 |
+
"""Log current memory usage"""
|
| 790 |
+
mem_usage = get_memory_usage()
|
| 791 |
+
peak_mem = get_peak_memory_usage()
|
| 792 |
+
gpu_util = get_gpu_utilization()
|
| 793 |
+
print(f"Memory usage at {stage}: {mem_usage:.2f}MB (Peak: {peak_mem:.2f}MB, GPU: {gpu_util:.2f}%)")
|
| 794 |
+
|
| 795 |
def toggle_dark_mode():
|
| 796 |
"""Toggle between light and dark themes"""
|
| 797 |
global DARK_MODE
|
| 798 |
DARK_MODE = not DARK_MODE
|
| 799 |
return gr.Theme.darkmode() if DARK_MODE else THEME
|
| 800 |
|
| 801 |
+
def get_space_upgrade_url():
|
| 802 |
+
"""Get the URL for upgrading the Space"""
|
| 803 |
+
if not is_running_in_space():
|
| 804 |
+
return "#"
|
| 805 |
+
|
| 806 |
+
space_id = os.environ.get("SPACE_ID", "")
|
| 807 |
+
if not space_id:
|
| 808 |
+
return "https://huggingface.co/pricing"
|
| 809 |
+
|
| 810 |
+
# Extract username and space name
|
| 811 |
+
parts = space_id.split("/")
|
| 812 |
+
if len(parts) != 2:
|
| 813 |
+
return "https://huggingface.co/pricing"
|
| 814 |
+
|
| 815 |
+
username, space_name = parts
|
| 816 |
+
return f"https://huggingface.co/spaces/{username}/{space_name}/settings"
|
| 817 |
+
|
| 818 |
+
def get_local_installation_instructions():
|
| 819 |
+
"""Get instructions for running the app locally"""
|
| 820 |
+
required_memory = estimate_model_memory_requirements()
|
| 821 |
+
repo_url = get_repository_url()
|
| 822 |
+
|
| 823 |
+
return f"""
|
| 824 |
+
## Running Locally
|
| 825 |
+
|
| 826 |
+
To run this application locally with the full model:
|
| 827 |
+
|
| 828 |
+
1. Clone the repository:
|
| 829 |
+
```bash
|
| 830 |
+
git clone {repo_url}
|
| 831 |
+
cd radar-analysis
|
| 832 |
+
```
|
| 833 |
+
|
| 834 |
+
2. Install dependencies:
|
| 835 |
+
```bash
|
| 836 |
+
pip install -r requirements.txt
|
| 837 |
+
```
|
| 838 |
+
|
| 839 |
+
3. Set your Hugging Face token as an environment variable:
|
| 840 |
+
```bash
|
| 841 |
+
export HF_TOCKEN=your_huggingface_token
|
| 842 |
+
```
|
| 843 |
+
|
| 844 |
+
4. Run the application:
|
| 845 |
+
```bash
|
| 846 |
+
python app.py
|
| 847 |
+
```
|
| 848 |
+
|
| 849 |
+
Make sure your system has at least {required_memory/1024:.1f}GB of RAM for optimal performance.
|
| 850 |
+
"""
|
| 851 |
+
|
| 852 |
+
def get_model_card_url():
|
| 853 |
+
"""Get the URL for the model card"""
|
| 854 |
+
return f"https://huggingface.co/{MODEL_NAME}"
|
| 855 |
+
|
| 856 |
+
def has_enough_memory_for_model():
|
| 857 |
+
"""Check if we have enough memory for the model"""
|
| 858 |
+
if is_running_in_space():
|
| 859 |
+
# In Spaces, we need to be more cautious
|
| 860 |
+
hardware_memory = get_space_hardware_memory() * 1024 # Convert GB to MB
|
| 861 |
+
required_memory = estimate_model_memory_requirements()
|
| 862 |
+
print(f"Space hardware memory: {hardware_memory}MB, Required: {required_memory:.2f}MB")
|
| 863 |
+
return hardware_memory >= required_memory
|
| 864 |
+
else:
|
| 865 |
+
# For local development, check available memory
|
| 866 |
+
available_memory = check_available_memory()
|
| 867 |
+
required_memory = estimate_model_memory_requirements()
|
| 868 |
+
print(f"Available memory: {available_memory:.2f}MB, Required: {required_memory:.2f}MB")
|
| 869 |
+
return available_memory >= required_memory
|
| 870 |
+
|
| 871 |
+
def get_repository_url():
|
| 872 |
+
"""Get the URL for the repository"""
|
| 873 |
+
if is_running_in_space():
|
| 874 |
+
space_id = os.environ.get("SPACE_ID", "")
|
| 875 |
+
if space_id:
|
| 876 |
+
# Space ID is in the format "username/spacename"
|
| 877 |
+
return f"https://huggingface.co/spaces/{space_id}"
|
| 878 |
+
else:
|
| 879 |
+
return "https://huggingface.co/spaces/xingqiang/radar-analysis"
|
| 880 |
+
else:
|
| 881 |
+
return "https://huggingface.co/spaces/xingqiang/radar-analysis"
|
| 882 |
+
|
| 883 |
+
def get_directory_name_from_repo_url(repo_url):
|
| 884 |
+
"""Get the directory name from the repository URL"""
|
| 885 |
+
# Extract the last part of the URL
|
| 886 |
+
parts = repo_url.rstrip('/').split('/')
|
| 887 |
+
return parts[-1]
|
| 888 |
+
|
| 889 |
+
# Launch the interface
|
| 890 |
+
def launch():
|
| 891 |
+
"""启动Gradio界面"""
|
| 892 |
+
if is_running_in_space():
|
| 893 |
+
# 在Spaces中,使用最小资源配置以避免内存问题
|
| 894 |
+
logger.info("使用最小资源配置启动Spaces")
|
| 895 |
+
iface.launch(
|
| 896 |
+
share=False,
|
| 897 |
+
server_name="0.0.0.0",
|
| 898 |
+
server_port=7860,
|
| 899 |
+
max_threads=4, # 从10减少到4
|
| 900 |
+
show_error=True,
|
| 901 |
+
quiet=False
|
| 902 |
+
)
|
| 903 |
+
else:
|
| 904 |
+
# 对于本地开发,使用默认设置
|
| 905 |
+
iface.launch()
|
| 906 |
+
|
| 907 |
# Create Gradio interface
|
| 908 |
with gr.Blocks(theme=THEME) as iface:
|
| 909 |
theme_state = gr.State(THEME)
|
| 910 |
+
|
| 911 |
with gr.Row():
|
| 912 |
+
gr.Markdown("# 雷达图像分析系统")
|
| 913 |
+
dark_mode_btn = gr.Button("🌓 切换暗黑模式", scale=0)
|
| 914 |
+
|
| 915 |
+
# 添加模型加载提示
|
| 916 |
+
gr.Markdown("""
|
| 917 |
+
### ℹ️ 模型加载说明
|
| 918 |
+
- 模型仅在您点击"分析"按钮时才会下载和初始化
|
| 919 |
+
- 首次分析可能需要较长时间,因为需要下载模型
|
| 920 |
+
- 如果内存不足,系统会自动切换到演示模式
|
| 921 |
+
""", elem_id="model-loading-notice")
|
| 922 |
+
|
| 923 |
if USE_DEMO_MODE:
|
| 924 |
+
hardware_type = get_space_hardware_type() if is_running_in_space() else "N/A"
|
| 925 |
+
hardware_tier = get_space_hardware_tier() if is_running_in_space() else "N/A"
|
| 926 |
+
hardware_memory = get_space_hardware_memory() if is_running_in_space() else 0
|
| 927 |
+
total_memory = get_total_system_memory()
|
| 928 |
+
required_memory = estimate_model_memory_requirements()
|
| 929 |
+
recommended_tier = get_recommended_space_tier()
|
| 930 |
+
upgrade_url = get_space_upgrade_url()
|
| 931 |
+
model_card_url = get_model_card_url()
|
| 932 |
+
|
| 933 |
+
memory_info = f"Space硬件: {hardware_type} (等级: {hardware_tier}, 内存: {hardware_memory}GB)"
|
| 934 |
+
model_req = f"[PaliGemma模型]({model_card_url})在使用8位量化加载时需要约{required_memory/1024:.1f}GB内存"
|
| 935 |
+
|
| 936 |
+
gr.Markdown(f"""
|
| 937 |
+
### ⚠️ 运行在演示模式
|
| 938 |
+
由于内存限制,应用程序当前在演示模式下运行:
|
| 939 |
+
|
| 940 |
+
1. **内存错误**: Space遇到"内存限制超过(16Gi)"错误
|
| 941 |
+
- {memory_info}
|
| 942 |
+
- 系统总内存: {total_memory:.2f}MB
|
| 943 |
+
- {model_req}
|
| 944 |
+
|
| 945 |
+
2. **解决方案**:
|
| 946 |
+
- 演示模式提供模拟结果用于演示目的
|
| 947 |
+
- 要使用完整模型,请在本地运行此应用程序,需要{required_memory/1024:.1f}GB+内存
|
| 948 |
+
- 或[升级到{recommended_tier} Space等级]({upgrade_url})或更高
|
| 949 |
+
|
| 950 |
+
演示模式仍提供所有UI功能和可视化特性。
|
| 951 |
""", elem_id="demo-mode-warning")
|
| 952 |
+
|
| 953 |
+
gr.Markdown("上传雷达图像以分析缺陷并生成技术报告")
|
| 954 |
+
|
| 955 |
with gr.Tabs() as tabs:
|
| 956 |
+
with gr.TabItem("分析", id="analysis"):
|
| 957 |
with gr.Row():
|
| 958 |
with gr.Column(scale=1):
|
| 959 |
+
with gr.Accordion("输入", open=True):
|
| 960 |
input_image = gr.Image(
|
| 961 |
+
type="pil",
|
| 962 |
+
label="上传雷达图像",
|
| 963 |
elem_id="input-image",
|
| 964 |
sources=["upload", "webcam", "clipboard"],
|
| 965 |
tool="editor"
|
| 966 |
)
|
| 967 |
tech_report_checkbox = gr.Checkbox(
|
| 968 |
+
label="生成技术报告",
|
| 969 |
value=False,
|
| 970 |
+
info="创建详细的技术分析报告"
|
| 971 |
)
|
| 972 |
analyze_button = gr.Button(
|
| 973 |
+
"分析",
|
| 974 |
variant="primary",
|
| 975 |
elem_id="analyze-btn"
|
| 976 |
)
|
| 977 |
+
|
| 978 |
with gr.Column(scale=2):
|
| 979 |
+
with gr.Accordion("检测结果", open=True):
|
| 980 |
output_image = gr.Image(
|
| 981 |
+
type="pil",
|
| 982 |
+
label="检测结果",
|
| 983 |
elem_id="output-image"
|
| 984 |
)
|
| 985 |
+
|
| 986 |
+
with gr.Accordion("分析报告", open=True):
|
| 987 |
output_report = gr.HTML(
|
| 988 |
+
label="分析报告",
|
| 989 |
elem_id="analysis-report"
|
| 990 |
)
|
| 991 |
tech_report_output = gr.File(
|
| 992 |
+
label="技术报告",
|
| 993 |
elem_id="tech-report"
|
| 994 |
)
|
| 995 |
+
|
| 996 |
with gr.Row():
|
| 997 |
with gr.Column():
|
| 998 |
confidence_plot = gr.Plot(
|
| 999 |
+
label="置信度分数",
|
| 1000 |
elem_id="confidence-plot"
|
| 1001 |
)
|
| 1002 |
+
|
| 1003 |
with gr.Column():
|
| 1004 |
feature_plot = gr.Plot(
|
| 1005 |
+
label="特征分析",
|
| 1006 |
elem_id="feature-plot"
|
| 1007 |
)
|
| 1008 |
+
|
| 1009 |
with gr.Row():
|
| 1010 |
heatmap_plot = gr.Plot(
|
| 1011 |
+
label="信号强度热图",
|
| 1012 |
elem_id="heatmap-plot"
|
| 1013 |
)
|
| 1014 |
+
|
| 1015 |
+
with gr.TabItem("历史", id="history"):
|
| 1016 |
with gr.Row():
|
| 1017 |
+
history_button = gr.Button("刷新历史")
|
| 1018 |
history_output = gr.HTML(elem_id="history-output")
|
| 1019 |
+
|
| 1020 |
+
with gr.TabItem("帮助", id="help"):
|
| 1021 |
gr.Markdown("""
|
| 1022 |
+
## 如何使用此工具
|
| 1023 |
+
|
| 1024 |
+
1. **上传图像**: 点击上传按钮选择要分析的雷达图像
|
| 1025 |
+
2. **生成技术报告** (可选): 如果需要详细的技术报告,请勾选此框
|
| 1026 |
+
3. **分析**: 点击分析按钮处理图像
|
| 1027 |
+
4. **查看结果**:
|
| 1028 |
+
- 检测可视化显示已识别的缺陷
|
| 1029 |
+
- 分析报告提供发现的摘要
|
| 1030 |
+
- 技术报告(如果请求)提供详细指标
|
| 1031 |
+
- 图表提供置信度分数和特征分析的可视化表示
|
| 1032 |
+
|
| 1033 |
+
## 关于模型
|
| 1034 |
+
|
| 1035 |
+
该系统使用[PaliGemma]({get_model_card_url()}),这是一个视觉-语言模型,结合了SigLIP-So400m(图像编码器)和Gemma-2B(文本解码器)进行联合目标检测和多模态分析。
|
| 1036 |
+
|
| 1037 |
+
该模型针对雷达图像分析进行了微调,可以检测结构检查图像中的各种类型的缺陷和异常。
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1038 |
""")
|
| 1039 |
+
|
| 1040 |
+
if USE_DEMO_MODE and is_running_in_space():
|
| 1041 |
+
gr.Markdown(get_local_installation_instructions())
|
| 1042 |
+
|
| 1043 |
+
gr.Markdown("""
|
| 1044 |
+
## 键盘快捷键
|
| 1045 |
+
|
| 1046 |
+
- **Ctrl+A**: 触发分析
|
| 1047 |
+
- **Ctrl+D**: 切换暗黑模式
|
| 1048 |
+
|
| 1049 |
+
## 故障排除
|
| 1050 |
+
|
| 1051 |
+
- 如果分析失败,请尝试上传不同的图像格式
|
| 1052 |
+
- 确保图像是有效的雷达扫描
|
| 1053 |
+
- 对于技术问题,请查看控制台日志
|
| 1054 |
+
""")
|
| 1055 |
+
|
| 1056 |
# Set up event handlers
|
| 1057 |
dark_mode_btn.click(
|
| 1058 |
fn=toggle_dark_mode,
|
|
|
|
| 1060 |
outputs=[iface],
|
| 1061 |
api_name="toggle_theme"
|
| 1062 |
)
|
| 1063 |
+
|
| 1064 |
analyze_button.click(
|
| 1065 |
fn=process_image_streaming,
|
| 1066 |
inputs=[input_image, tech_report_checkbox],
|
| 1067 |
outputs=[output_image, output_report, tech_report_output, confidence_plot, feature_plot, heatmap_plot],
|
| 1068 |
api_name="analyze"
|
| 1069 |
)
|
| 1070 |
+
|
| 1071 |
history_button.click(
|
| 1072 |
fn=display_history,
|
| 1073 |
inputs=[],
|
| 1074 |
outputs=[history_output],
|
| 1075 |
api_name="history"
|
| 1076 |
)
|
| 1077 |
+
|
| 1078 |
# Add keyboard shortcuts
|
| 1079 |
iface.load(lambda: None, None, None, _js="""
|
| 1080 |
() => {
|
|
|
|
| 1083 |
document.getElementById('analyze-btn').click();
|
| 1084 |
}
|
| 1085 |
if (e.key === 'd' && e.ctrlKey) {
|
| 1086 |
+
document.querySelector('button:contains("切换暗黑模式")').click();
|
| 1087 |
}
|
| 1088 |
});
|
| 1089 |
}
|
| 1090 |
""")
|
| 1091 |
|
| 1092 |
# Launch the interface
|
| 1093 |
+
launch()
|
config.py
CHANGED
|
@@ -2,7 +2,7 @@ import os
|
|
| 2 |
|
| 3 |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 4 |
|
| 5 |
-
MODEL_NAME = "Extremely4606/
|
| 6 |
DATABASE_URL = f"sqlite:///{os.path.join(BASE_DIR, 'radar_reports.db')}"
|
| 7 |
|
| 8 |
AMPLITUDE_THRESHOLD = 128
|
|
|
|
| 2 |
|
| 3 |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 4 |
|
| 5 |
+
MODEL_NAME = "Extremely4606/paligemma24_12_30"
|
| 6 |
DATABASE_URL = f"sqlite:///{os.path.join(BASE_DIR, 'radar_reports.db')}"
|
| 7 |
|
| 8 |
AMPLITUDE_THRESHOLD = 128
|
create_space.py
CHANGED
|
@@ -4,13 +4,23 @@ import sys
|
|
| 4 |
|
| 5 |
def create_and_push_space():
|
| 6 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
print("Creating Space...")
|
| 8 |
# Create the space
|
|
|
|
| 9 |
repo_url = create_repo(
|
| 10 |
repo_id="xingqiang/radar-analysis",
|
| 11 |
repo_type="space",
|
| 12 |
space_sdk="gradio",
|
| 13 |
-
private=False
|
|
|
|
| 14 |
)
|
| 15 |
print(f"Space created successfully at: {repo_url}")
|
| 16 |
|
|
|
|
| 4 |
|
| 5 |
def create_and_push_space():
|
| 6 |
try:
|
| 7 |
+
# Get Hugging Face token from environment
|
| 8 |
+
hf_token = os.environ.get("HF_TOKEN")
|
| 9 |
+
if not hf_token:
|
| 10 |
+
print("Error: HF_TOKEN environment variable not set")
|
| 11 |
+
print("Please set your Hugging Face token as an environment variable:")
|
| 12 |
+
print("export HF_TOKEN=your_token_here")
|
| 13 |
+
sys.exit(1)
|
| 14 |
+
|
| 15 |
print("Creating Space...")
|
| 16 |
# Create the space
|
| 17 |
+
api = HfApi(token=hf_token)
|
| 18 |
repo_url = create_repo(
|
| 19 |
repo_id="xingqiang/radar-analysis",
|
| 20 |
repo_type="space",
|
| 21 |
space_sdk="gradio",
|
| 22 |
+
private=False,
|
| 23 |
+
token=hf_token
|
| 24 |
)
|
| 25 |
print(f"Space created successfully at: {repo_url}")
|
| 26 |
|
feature_extraction.py
CHANGED
|
@@ -40,13 +40,30 @@ def classify_reflections(count):
|
|
| 40 |
|
| 41 |
|
| 42 |
def extract_features(image, detection_result):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
np_image = np.array(image)
|
| 44 |
amplitude = calculate_amplitude(np_image)
|
| 45 |
amplitude_class = classify_amplitude(amplitude)
|
| 46 |
|
| 47 |
-
box
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
attenuation_rate = calculate_attenuation_rate(np_image)
|
| 52 |
attenuation_class = classify_attenuation_rate(attenuation_rate)
|
|
@@ -59,8 +76,8 @@ def extract_features(image, detection_result):
|
|
| 59 |
"分布范围": distribution_class,
|
| 60 |
"衰减速度": attenuation_class,
|
| 61 |
"反射次数": reflection_class,
|
| 62 |
-
"振幅值": amplitude,
|
| 63 |
-
"分布范围值": distribution_range,
|
| 64 |
-
"衰减速度值": attenuation_rate,
|
| 65 |
-
"反射次数值": reflection_count
|
| 66 |
}
|
|
|
|
| 40 |
|
| 41 |
|
| 42 |
def extract_features(image, detection_result):
|
| 43 |
+
"""
|
| 44 |
+
Extract features from the image and detection result.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
image: PIL Image
|
| 48 |
+
detection_result: Dictionary containing detection results
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
Dictionary of features
|
| 52 |
+
"""
|
| 53 |
np_image = np.array(image)
|
| 54 |
amplitude = calculate_amplitude(np_image)
|
| 55 |
amplitude_class = classify_amplitude(amplitude)
|
| 56 |
|
| 57 |
+
# Handle box calculation
|
| 58 |
+
if detection_result and 'boxes' in detection_result and len(detection_result['boxes']) > 0:
|
| 59 |
+
box = detection_result['boxes'][0]
|
| 60 |
+
if not isinstance(box, list):
|
| 61 |
+
box = box.tolist()
|
| 62 |
+
distribution_range = calculate_distribution_range(box)
|
| 63 |
+
distribution_class = classify_distribution_range(distribution_range)
|
| 64 |
+
else:
|
| 65 |
+
distribution_range = 0
|
| 66 |
+
distribution_class = "小"
|
| 67 |
|
| 68 |
attenuation_rate = calculate_attenuation_rate(np_image)
|
| 69 |
attenuation_class = classify_attenuation_rate(attenuation_rate)
|
|
|
|
| 76 |
"分布范围": distribution_class,
|
| 77 |
"衰减速度": attenuation_class,
|
| 78 |
"反射次数": reflection_class,
|
| 79 |
+
"振幅值": float(amplitude),
|
| 80 |
+
"分布范围值": float(distribution_range),
|
| 81 |
+
"衰减速度值": float(attenuation_rate),
|
| 82 |
+
"反射次数值": int(reflection_count)
|
| 83 |
}
|
model.py
CHANGED
|
@@ -6,101 +6,205 @@ import logging
|
|
| 6 |
from transformers import AutoProcessor, AutoModelForVision2Seq
|
| 7 |
from PIL import Image
|
| 8 |
import numpy as np
|
|
|
|
| 9 |
|
|
|
|
|
|
|
|
|
|
| 10 |
logger = logging.getLogger(__name__)
|
| 11 |
|
| 12 |
class RadarDetectionModel:
|
| 13 |
-
def __init__(self, model_name=
|
| 14 |
"""
|
| 15 |
-
|
| 16 |
|
| 17 |
Args:
|
| 18 |
-
model_name (str):
|
| 19 |
-
use_auth_token (str, optional): Hugging Face
|
| 20 |
"""
|
| 21 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
| 27 |
else:
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
def detect(self, image):
|
| 35 |
"""
|
| 36 |
-
|
| 37 |
|
| 38 |
Args:
|
| 39 |
-
image (PIL.Image):
|
| 40 |
|
| 41 |
Returns:
|
| 42 |
-
dict:
|
| 43 |
"""
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
)
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
def _parse_detection_results(self, text, image_size):
|
| 71 |
"""
|
| 72 |
-
|
| 73 |
|
| 74 |
Args:
|
| 75 |
-
text (str):
|
| 76 |
-
image_size (tuple):
|
| 77 |
|
| 78 |
Returns:
|
| 79 |
tuple: (boxes, scores, labels)
|
| 80 |
"""
|
| 81 |
-
#
|
| 82 |
-
#
|
| 83 |
|
| 84 |
-
#
|
| 85 |
defects = []
|
| 86 |
|
| 87 |
-
if "crack" in text.lower():
|
| 88 |
-
defects.append(("
|
| 89 |
|
| 90 |
-
if "corrosion" in text.lower():
|
| 91 |
-
defects.append(("
|
| 92 |
|
| 93 |
-
if "damage" in text.lower():
|
| 94 |
-
defects.append(("
|
| 95 |
|
| 96 |
-
if "defect" in text.lower():
|
| 97 |
-
defects.append(("
|
| 98 |
|
| 99 |
-
#
|
| 100 |
if not defects:
|
| 101 |
-
defects.append(("
|
| 102 |
|
| 103 |
-
#
|
| 104 |
width, height = image_size
|
| 105 |
boxes = []
|
| 106 |
scores = []
|
|
|
|
| 6 |
from transformers import AutoProcessor, AutoModelForVision2Seq
|
| 7 |
from PIL import Image
|
| 8 |
import numpy as np
|
| 9 |
+
from config import MODEL_NAME
|
| 10 |
|
| 11 |
+
# 配置日志记录
|
| 12 |
+
logging.basicConfig(level=logging.INFO,
|
| 13 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
| 14 |
logger = logging.getLogger(__name__)
|
| 15 |
|
| 16 |
class RadarDetectionModel:
|
| 17 |
+
def __init__(self, model_name=None, use_auth_token=None):
|
| 18 |
"""
|
| 19 |
+
初始化雷达检测模型。
|
| 20 |
|
| 21 |
Args:
|
| 22 |
+
model_name (str): 要加载的模型名称或路径
|
| 23 |
+
use_auth_token (str, optional): 用于访问受限模型的Hugging Face令牌
|
| 24 |
"""
|
| 25 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 26 |
+
logger.info(f"使用设备: {self.device}")
|
| 27 |
|
| 28 |
+
self.model_name = model_name if model_name else MODEL_NAME
|
| 29 |
+
logger.info(f"模型名称: {self.model_name}")
|
| 30 |
+
|
| 31 |
+
self.use_auth_token = use_auth_token or os.environ.get("HF_TOKEN")
|
| 32 |
+
if self.use_auth_token:
|
| 33 |
+
logger.info("已提供Hugging Face令牌")
|
| 34 |
else:
|
| 35 |
+
logger.warning("未提供Hugging Face令牌,可能无法访问受限模型")
|
| 36 |
+
|
| 37 |
+
self.processor = None
|
| 38 |
+
self.model = None
|
| 39 |
+
|
| 40 |
+
# 加载模型和处理器
|
| 41 |
+
logger.info("开始加载模型和处理器...")
|
| 42 |
+
self._load_model()
|
| 43 |
+
|
| 44 |
+
def _load_model(self):
|
| 45 |
+
"""加载模型和处理器,并监控内存使用情况"""
|
| 46 |
+
try:
|
| 47 |
+
logger.info(f"正在从{self.model_name}加载处理器")
|
| 48 |
+
start_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
|
| 49 |
+
end_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
|
| 50 |
+
|
| 51 |
+
if start_time:
|
| 52 |
+
start_time.record()
|
| 53 |
+
|
| 54 |
+
if self.use_auth_token:
|
| 55 |
+
# 如果提供了令牌,登录到Hugging Face Hub
|
| 56 |
+
logger.info("使用令牌登录到Hugging Face Hub")
|
| 57 |
+
login(token=self.use_auth_token)
|
| 58 |
+
self.processor = AutoProcessor.from_pretrained(self.model_name, use_auth_token=self.use_auth_token)
|
| 59 |
+
else:
|
| 60 |
+
self.processor = AutoProcessor.from_pretrained(self.model_name)
|
| 61 |
+
|
| 62 |
+
if end_time:
|
| 63 |
+
end_time.record()
|
| 64 |
+
torch.cuda.synchronize()
|
| 65 |
+
logger.info(f"处理器加载时间: {start_time.elapsed_time(end_time):.2f}毫秒")
|
| 66 |
+
|
| 67 |
+
logger.info(f"正在从{self.model_name}加载模型,使用8位量化以减少内存使用")
|
| 68 |
|
| 69 |
+
if start_time:
|
| 70 |
+
start_time.record()
|
| 71 |
+
|
| 72 |
+
# 使用8位量化以减少内存使用
|
| 73 |
+
if self.use_auth_token:
|
| 74 |
+
self.model = AutoModelForVision2Seq.from_pretrained(
|
| 75 |
+
self.model_name,
|
| 76 |
+
use_auth_token=self.use_auth_token,
|
| 77 |
+
load_in_8bit=True, # 使用8位量化
|
| 78 |
+
device_map="auto" # 自动管理设备放置
|
| 79 |
+
)
|
| 80 |
+
else:
|
| 81 |
+
self.model = AutoModelForVision2Seq.from_pretrained(
|
| 82 |
+
self.model_name,
|
| 83 |
+
load_in_8bit=True, # 使用8位量化
|
| 84 |
+
device_map="auto" # 自动管理设备放置
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
if end_time:
|
| 88 |
+
end_time.record()
|
| 89 |
+
torch.cuda.synchronize()
|
| 90 |
+
logger.info(f"模型加载时间: {start_time.elapsed_time(end_time):.2f}毫秒")
|
| 91 |
+
|
| 92 |
+
logger.info(f"模型加载成功")
|
| 93 |
+
# 使用device_map="auto"时无需手动移动到设备
|
| 94 |
+
self.model.eval()
|
| 95 |
+
|
| 96 |
+
# 记录模型信息
|
| 97 |
+
param_count = sum(p.numel() for p in self.model.parameters())
|
| 98 |
+
logger.info(f"模型参数数量: {param_count:,}")
|
| 99 |
+
|
| 100 |
+
if torch.cuda.is_available():
|
| 101 |
+
memory_allocated = torch.cuda.memory_allocated() / (1024 * 1024)
|
| 102 |
+
memory_reserved = torch.cuda.memory_reserved() / (1024 * 1024)
|
| 103 |
+
logger.info(f"GPU内存分配: {memory_allocated:.2f}MB")
|
| 104 |
+
logger.info(f"GPU内存保留: {memory_reserved:.2f}MB")
|
| 105 |
+
|
| 106 |
+
except Exception as e:
|
| 107 |
+
logger.error(f"加载模型时出错: {str(e)}")
|
| 108 |
+
raise
|
| 109 |
|
| 110 |
def detect(self, image):
|
| 111 |
"""
|
| 112 |
+
检测雷达图像中的对象。
|
| 113 |
|
| 114 |
Args:
|
| 115 |
+
image (PIL.Image): 要分析的雷达图像
|
| 116 |
|
| 117 |
Returns:
|
| 118 |
+
dict: 检测结果,包括边界框、分数和标签
|
| 119 |
"""
|
| 120 |
+
try:
|
| 121 |
+
if self.model is None or self.processor is None:
|
| 122 |
+
raise ValueError("模型或处理器未正确初始化")
|
| 123 |
+
|
| 124 |
+
# 预处理图像
|
| 125 |
+
logger.info("预处理图像")
|
| 126 |
+
inputs = self.processor(images=image, return_tensors="pt").to(self.device)
|
| 127 |
+
|
| 128 |
+
# 运行推理
|
| 129 |
+
logger.info("运行模型推理")
|
| 130 |
+
start_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
|
| 131 |
+
end_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
|
| 132 |
+
|
| 133 |
+
if start_time:
|
| 134 |
+
start_time.record()
|
| 135 |
+
|
| 136 |
+
with torch.no_grad():
|
| 137 |
+
outputs = self.model.generate(
|
| 138 |
+
**inputs,
|
| 139 |
+
max_length=50,
|
| 140 |
+
num_beams=4,
|
| 141 |
+
early_stopping=True
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
if end_time:
|
| 145 |
+
end_time.record()
|
| 146 |
+
torch.cuda.synchronize()
|
| 147 |
+
inference_time = start_time.elapsed_time(end_time)
|
| 148 |
+
logger.info(f"推理时间: {inference_time:.2f}毫秒")
|
| 149 |
+
|
| 150 |
+
# 处理输出
|
| 151 |
+
generated_text = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
|
| 152 |
+
logger.info(f"生成的文本: {generated_text}")
|
| 153 |
+
|
| 154 |
+
# 从生成的文本中解析检测结果
|
| 155 |
+
boxes, scores, labels = self._parse_detection_results(generated_text, image.size)
|
| 156 |
+
logger.info(f"检测到{len(boxes)}个对象")
|
| 157 |
+
|
| 158 |
+
return {
|
| 159 |
+
'boxes': boxes,
|
| 160 |
+
'scores': scores,
|
| 161 |
+
'labels': labels,
|
| 162 |
+
'image': image
|
| 163 |
+
}
|
| 164 |
+
except Exception as e:
|
| 165 |
+
logger.error(f"检测过程中出错: {str(e)}")
|
| 166 |
+
# 返回备用检测结果
|
| 167 |
+
return {
|
| 168 |
+
'boxes': [[100, 100, 200, 200]],
|
| 169 |
+
'scores': [0.75],
|
| 170 |
+
'labels': ['错误: ' + str(e)[:50]],
|
| 171 |
+
'image': image
|
| 172 |
+
}
|
| 173 |
|
| 174 |
def _parse_detection_results(self, text, image_size):
|
| 175 |
"""
|
| 176 |
+
从生成的文本中解析检测结果。
|
| 177 |
|
| 178 |
Args:
|
| 179 |
+
text (str): 模型生成的文本
|
| 180 |
+
image_size (tuple): 输入图像的大小(宽度, 高度)
|
| 181 |
|
| 182 |
Returns:
|
| 183 |
tuple: (boxes, scores, labels)
|
| 184 |
"""
|
| 185 |
+
# 这是一个简化的示例 - 实际解析将取决于模型输出格式
|
| 186 |
+
# 为了演示,我们将提取一些模拟检测结果
|
| 187 |
|
| 188 |
+
# 检查文本中常见的缺陷关键词
|
| 189 |
defects = []
|
| 190 |
|
| 191 |
+
if "crack" in text.lower() or "裂缝" in text.lower():
|
| 192 |
+
defects.append(("裂缝", 0.92, [0.2, 0.3, 0.4, 0.5]))
|
| 193 |
|
| 194 |
+
if "corrosion" in text.lower() or "腐蚀" in text.lower():
|
| 195 |
+
defects.append(("腐蚀", 0.85, [0.6, 0.2, 0.8, 0.4]))
|
| 196 |
|
| 197 |
+
if "damage" in text.lower() or "损坏" in text.lower():
|
| 198 |
+
defects.append(("损坏", 0.78, [0.1, 0.7, 0.3, 0.9]))
|
| 199 |
|
| 200 |
+
if "defect" in text.lower() or "缺陷" in text.lower():
|
| 201 |
+
defects.append(("缺陷", 0.88, [0.5, 0.5, 0.7, 0.7]))
|
| 202 |
|
| 203 |
+
# 如果没有找到缺陷,添加一个通用的
|
| 204 |
if not defects:
|
| 205 |
+
defects.append(("异常", 0.75, [0.4, 0.4, 0.6, 0.6]))
|
| 206 |
|
| 207 |
+
# 将归一化坐标转换为像素坐标
|
| 208 |
width, height = image_size
|
| 209 |
boxes = []
|
| 210 |
scores = []
|
requirements.txt
CHANGED
|
@@ -1,22 +1,24 @@
|
|
| 1 |
-
gradio
|
| 2 |
-
torch
|
| 3 |
-
transformers
|
| 4 |
-
|
| 5 |
-
numpy
|
| 6 |
matplotlib>=3.8.2
|
| 7 |
-
pandas
|
| 8 |
sqlalchemy>=2.0.25
|
| 9 |
-
plotly
|
| 10 |
scikit-learn>=1.3.2
|
| 11 |
jinja2>=3.1.3
|
| 12 |
-
huggingface-hub
|
| 13 |
python-dotenv>=1.0.0
|
| 14 |
markdown>=3.5.1
|
| 15 |
-
psutil
|
| 16 |
tqdm>=4.66.1
|
| 17 |
-
accelerate
|
| 18 |
safetensors>=0.4.1
|
| 19 |
peft>=0.7.1
|
| 20 |
optimum>=1.14.0
|
| 21 |
colorama>=0.4.6
|
| 22 |
-
rich>=13.7.0
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio==4.19.2
|
| 2 |
+
torch==2.1.2
|
| 3 |
+
transformers==4.37.2
|
| 4 |
+
pillow==10.1.0
|
| 5 |
+
numpy==1.26.2
|
| 6 |
matplotlib>=3.8.2
|
| 7 |
+
pandas==2.1.3
|
| 8 |
sqlalchemy>=2.0.25
|
| 9 |
+
plotly==5.18.0
|
| 10 |
scikit-learn>=1.3.2
|
| 11 |
jinja2>=3.1.3
|
| 12 |
+
huggingface-hub==0.20.2
|
| 13 |
python-dotenv>=1.0.0
|
| 14 |
markdown>=3.5.1
|
| 15 |
+
psutil==5.9.6
|
| 16 |
tqdm>=4.66.1
|
| 17 |
+
accelerate==0.25.0
|
| 18 |
safetensors>=0.4.1
|
| 19 |
peft>=0.7.1
|
| 20 |
optimum>=1.14.0
|
| 21 |
colorama>=0.4.6
|
| 22 |
+
rich>=13.7.0
|
| 23 |
+
bitsandbytes==0.41.1
|
| 24 |
+
scipy>=1.11.3
|
run.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import logging
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
|
| 6 |
+
# Load environment variables from .env file if it exists
|
| 7 |
+
load_dotenv()
|
| 8 |
+
|
| 9 |
+
# Configure logging
|
| 10 |
+
logging.basicConfig(level=logging.INFO,
|
| 11 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
def main():
|
| 15 |
+
"""Run the Radar Analysis System application"""
|
| 16 |
+
try:
|
| 17 |
+
# Check for HF_TOKEN environment variable
|
| 18 |
+
hf_token = os.environ.get("HF_TOKEN")
|
| 19 |
+
if not hf_token:
|
| 20 |
+
logger.warning("HF_TOKEN environment variable not set. The application will run in demo mode.")
|
| 21 |
+
else:
|
| 22 |
+
logger.info("HF_TOKEN environment variable found.")
|
| 23 |
+
|
| 24 |
+
# Import app module
|
| 25 |
+
import app
|
| 26 |
+
|
| 27 |
+
# Launch the application
|
| 28 |
+
logger.info("Starting Radar Analysis System...")
|
| 29 |
+
app.launch()
|
| 30 |
+
|
| 31 |
+
return True
|
| 32 |
+
except Exception as e:
|
| 33 |
+
logger.error(f"Error running application: {str(e)}")
|
| 34 |
+
return False
|
| 35 |
+
|
| 36 |
+
if __name__ == "__main__":
|
| 37 |
+
success = main()
|
| 38 |
+
sys.exit(0 if success else 1)
|
test_app.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import logging
|
| 4 |
+
|
| 5 |
+
# Configure logging
|
| 6 |
+
logging.basicConfig(level=logging.INFO,
|
| 7 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
def test_imports():
|
| 11 |
+
"""测试所有必需的模块都可以导入"""
|
| 12 |
+
try:
|
| 13 |
+
import torch
|
| 14 |
+
logger.info(f"PyTorch版本: {torch.__version__}")
|
| 15 |
+
|
| 16 |
+
import transformers
|
| 17 |
+
logger.info(f"Transformers版本: {transformers.__version__}")
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
logger.info(f"NumPy版本: {np.__version__}")
|
| 21 |
+
|
| 22 |
+
import PIL
|
| 23 |
+
logger.info(f"PIL版本: {PIL.__version__}")
|
| 24 |
+
|
| 25 |
+
import scipy
|
| 26 |
+
logger.info(f"SciPy版本: {scipy.__version__}")
|
| 27 |
+
|
| 28 |
+
logger.info("所有导入成功")
|
| 29 |
+
return True
|
| 30 |
+
except ImportError as e:
|
| 31 |
+
logger.error(f"导入错误: {str(e)}")
|
| 32 |
+
return False
|
| 33 |
+
|
| 34 |
+
def test_model_loading():
|
| 35 |
+
"""测试模型可以加载"""
|
| 36 |
+
try:
|
| 37 |
+
from model import RadarDetectionModel
|
| 38 |
+
|
| 39 |
+
# 检查是否设置了HF_TOKEN环境变量
|
| 40 |
+
hf_token = os.environ.get("HF_TOKEN")
|
| 41 |
+
if not hf_token:
|
| 42 |
+
logger.warning("未设置HF_TOKEN环境变量,使用公共模型进行测试")
|
| 43 |
+
|
| 44 |
+
# 尝试初始化模型,使用较小的公共模型
|
| 45 |
+
logger.info("尝试初始化模型(使用较小的公共模型)")
|
| 46 |
+
model = RadarDetectionModel(model_name="google/siglip-base-patch16-224")
|
| 47 |
+
logger.info("模型初始化成功")
|
| 48 |
+
return True
|
| 49 |
+
except Exception as e:
|
| 50 |
+
logger.error(f"模型加载错误: {str(e)}")
|
| 51 |
+
return False
|
| 52 |
+
|
| 53 |
+
def test_feature_extraction():
|
| 54 |
+
"""测试特征提取功能"""
|
| 55 |
+
try:
|
| 56 |
+
import numpy as np
|
| 57 |
+
from PIL import Image
|
| 58 |
+
from feature_extraction import extract_features
|
| 59 |
+
|
| 60 |
+
# 创建一个虚拟图像和检测结果
|
| 61 |
+
logger.info("创建虚拟测试数据")
|
| 62 |
+
dummy_image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
|
| 63 |
+
dummy_detection = {
|
| 64 |
+
'boxes': [[50, 50, 100, 100]],
|
| 65 |
+
'scores': [0.9],
|
| 66 |
+
'labels': ['测试']
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
# 提取特征
|
| 70 |
+
logger.info("提取特征")
|
| 71 |
+
features = extract_features(dummy_image, dummy_detection)
|
| 72 |
+
logger.info(f"提取的特征: {features}")
|
| 73 |
+
return True
|
| 74 |
+
except Exception as e:
|
| 75 |
+
logger.error(f"特征提取错误: {str(e)}")
|
| 76 |
+
return False
|
| 77 |
+
|
| 78 |
+
def test_app_initialization():
|
| 79 |
+
"""测试应用程序初始化但不加载模型"""
|
| 80 |
+
try:
|
| 81 |
+
logger.info("测试应用程序初始化")
|
| 82 |
+
import app
|
| 83 |
+
|
| 84 |
+
# 检查应用程序是否已初始化但没有加载模型
|
| 85 |
+
logger.info("检查应用程序全局变量")
|
| 86 |
+
assert app.model is None, "模型不应该在导入时加载"
|
| 87 |
+
assert app.MODEL_INIT_ATTEMPTED is False, "模型初始化尝试标志应为False"
|
| 88 |
+
|
| 89 |
+
logger.info("应用程序初始化测试通过")
|
| 90 |
+
return True
|
| 91 |
+
except Exception as e:
|
| 92 |
+
logger.error(f"应用程序初始化错误: {str(e)}")
|
| 93 |
+
return False
|
| 94 |
+
|
| 95 |
+
def run_tests():
|
| 96 |
+
"""运行所有测试"""
|
| 97 |
+
tests = [
|
| 98 |
+
("导入测试", test_imports),
|
| 99 |
+
("应用程序初始化测试", test_app_initialization),
|
| 100 |
+
("模型加载测试", test_model_loading),
|
| 101 |
+
("特征提取测试", test_feature_extraction)
|
| 102 |
+
]
|
| 103 |
+
|
| 104 |
+
results = []
|
| 105 |
+
for name, test_func in tests:
|
| 106 |
+
logger.info(f"运行{name}...")
|
| 107 |
+
try:
|
| 108 |
+
result = test_func()
|
| 109 |
+
results.append((name, result))
|
| 110 |
+
logger.info(f"{name}: {'通过' if result else '失败'}")
|
| 111 |
+
except Exception as e:
|
| 112 |
+
logger.error(f"{name}失败,错误: {str(e)}")
|
| 113 |
+
results.append((name, False))
|
| 114 |
+
|
| 115 |
+
# 打印摘要
|
| 116 |
+
logger.info("\n--- 测试摘要 ---")
|
| 117 |
+
passed = sum(1 for _, result in results if result)
|
| 118 |
+
total = len(results)
|
| 119 |
+
logger.info(f"通过: {passed}/{total} 测试")
|
| 120 |
+
|
| 121 |
+
for name, result in results:
|
| 122 |
+
status = "通过" if result else "失败"
|
| 123 |
+
logger.info(f"{name}: {status}")
|
| 124 |
+
|
| 125 |
+
return passed == total
|
| 126 |
+
|
| 127 |
+
if __name__ == "__main__":
|
| 128 |
+
success = run_tests()
|
| 129 |
+
sys.exit(0 if success else 1)
|
utils.py
CHANGED
|
@@ -9,10 +9,12 @@ def plot_detection(image, detection_result):
|
|
| 9 |
ax = plt.gca()
|
| 10 |
|
| 11 |
for score, label, box in zip(detection_result["scores"], detection_result["labels"], detection_result["boxes"]):
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
| 14 |
ax.add_patch(rect)
|
| 15 |
-
ax.text(
|
| 16 |
bbox=dict(facecolor='white', alpha=0.8))
|
| 17 |
|
| 18 |
plt.axis('off')
|
|
|
|
| 9 |
ax = plt.gca()
|
| 10 |
|
| 11 |
for score, label, box in zip(detection_result["scores"], detection_result["labels"], detection_result["boxes"]):
|
| 12 |
+
x1, y1, x2, y2 = box
|
| 13 |
+
width = x2 - x1
|
| 14 |
+
height = y2 - y1
|
| 15 |
+
rect = plt.Rectangle((x1, y1), width, height, fill=False, color='red', linewidth=2)
|
| 16 |
ax.add_patch(rect)
|
| 17 |
+
ax.text(x1, y1, f'{label}: {score:.2f}',
|
| 18 |
bbox=dict(facecolor='white', alpha=0.8))
|
| 19 |
|
| 20 |
plt.axis('off')
|