tyrwh commited on
Commit
a44ba26
·
1 Parent(s): 8beb05b

Retrying gitignore update, 3rd attempt

Browse files
Files changed (2) hide show
  1. .gitignore +2 -1
  2. app.py +100 -57
.gitignore CHANGED
@@ -1,4 +1,5 @@
1
  .DS_Store
2
  weights.pt
3
  weights_nemaquant.v1.onnx
4
- results/
 
 
1
  .DS_Store
2
  weights.pt
3
  weights_nemaquant.v1.onnx
4
+ results/
5
+ *.pyc
app.py CHANGED
@@ -19,6 +19,9 @@ import zipfile
19
  import cv2
20
  import csv
21
  import numpy as np
 
 
 
22
 
23
  from yolo_utils import load_model, detect_image
24
 
@@ -35,7 +38,8 @@ app.config['ALLOWED_EXTENSIONS'] = {'png', 'jpg', 'jpeg', 'tif', 'tiff'}
35
  UPLOAD_FOLDER.mkdir(parents=True, exist_ok=True)
36
  RESULT_FOLDER.mkdir(parents=True, exist_ok=True)
37
 
38
- job_status = {}
 
39
 
40
  @app.errorhandler(Exception)
41
  def handle_exception(e):
@@ -58,40 +62,65 @@ def get_model():
58
  _model = load_model(WEIGHTS_FILE)
59
  return _model
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  all_detections = {}
62
 
63
  def process_image(args):
64
- filename, image_bytes = args
65
  model = get_model()
66
  detections = detect_image(model, image_bytes, conf=0.05)
67
- # Do NOT update all_detections here (worker process)
68
- # Save original image to uploads for later annotation
69
- img_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
70
- with open(img_path, 'wb') as f:
71
- f.write(image_bytes)
72
- return {'filename': filename, 'detections': detections}
73
 
74
  def async_process_images(job_id, file_data):
75
  try:
76
- job_status[job_id] = {'status': 'running', 'progress': 0, 'results': []}
 
 
 
77
  total = len(file_data)
78
  results = []
 
79
  with Pool(processes=min(cpu_count(), total)) as pool:
80
  for idx, result in enumerate(pool.imap(process_image, file_data)):
81
- results.append(result)
82
- # Update progress (0-100)
83
- job_status[job_id]['progress'] = int((idx + 1) / total * 100)
84
- # Aggregate results
85
- for result in results:
86
- all_detections[result['filename']] = result['detections']
87
- # Add num_eggs to each result for frontend compatibility
88
- for result in results:
89
- result['num_eggs'] = sum(1 for d in result['detections'] if d.get('class') == 'egg')
90
- job_status[job_id]['status'] = 'success'
91
- job_status[job_id]['results'] = results
92
- job_status[job_id]['progress'] = 100
 
93
  except Exception as e:
94
- job_status[job_id] = {'status': 'error', 'error': str(e), 'progress': 100}
 
 
 
 
95
 
96
  @app.route('/process', methods=['POST'])
97
  def process_images():
@@ -99,10 +128,19 @@ def process_images():
99
  files = request.files.getlist('files')
100
  if not files or files[0].filename == '':
101
  return jsonify({'error': 'No files uploaded'}), 400
102
-
103
- file_data = [(secure_filename(f.filename), f.read()) for f in files]
104
  job_id = str(uuid.uuid4())
105
- job_status[job_id] = {'status': 'starting', 'progress': 0}
 
 
 
 
 
 
 
 
 
 
 
106
  thread = Thread(target=async_process_images, args=(job_id, file_data))
107
  thread.daemon = True
108
  thread.start()
@@ -112,33 +150,33 @@ def process_images():
112
  print(traceback.format_exc())
113
  return jsonify({'error': str(e)}), 500
114
 
115
- @app.route('/annotate', methods=['POST'])
116
- def annotate_image():
117
- data = request.json
118
- filename = secure_filename(data['filename'])
119
- threshold = float(data['confidence'])
120
- img_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
121
- img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
122
- detections = all_detections.get(filename, [])
123
- filtered = [d for d in detections if d['score'] >= threshold]
124
- # Draw boxes
125
- for det in filtered:
126
- x1, y1, x2, y2 = map(int, det['bbox'])
127
- cv2.rectangle(img, (x1, y1), (x2, y2), (0,0,255), 3)
128
- temp_path = os.path.join(tempfile.gettempdir(), 'annotated.png')
129
- cv2.imwrite(temp_path, img)
130
- return send_file(temp_path, mimetype='image/png')
 
131
 
132
  @app.route('/progress/<job_id>')
133
  def get_progress(job_id):
134
- status = job_status.get(job_id)
135
- if not status:
136
  return jsonify({"status": "error", "error": "Job ID not found"}), 404
137
  # Add a mapping from filename to detections for frontend plotting
138
- if 'results' in status:
139
- detections_by_filename = {r['filename']: r['detections'] for r in status['results']}
140
- status['detections_by_filename'] = detections_by_filename
141
- return jsonify(status)
142
 
143
  @app.route('/results/<job_id>/<path:filename>')
144
  def download_file(job_id, filename):
@@ -245,12 +283,15 @@ def export_images(job_id):
245
  def export_csv():
246
  try:
247
  data = request.json
 
248
  threshold = float(data.get('confidence', 0.5))
249
- # all_detections: {filename: [detections]}
 
 
250
  rows = []
251
- for filename, detections in all_detections.items():
252
  count = sum(1 for d in detections if d['score'] >= threshold)
253
- rows.append({'Filename': filename, 'EggsDetected': count})
254
  timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
255
  output = io.StringIO()
256
  writer = csv.DictWriter(output, fieldnames=['Filename', 'EggsDetected'])
@@ -273,20 +314,22 @@ def export_csv():
273
  def export_images_post():
274
  try:
275
  data = request.json
 
276
  threshold = float(data.get('confidence', 0.5))
277
- # all_detections: {filename: [detections]}
 
 
278
  memory_file = io.BytesIO()
279
  with zipfile.ZipFile(memory_file, 'w', zipfile.ZIP_DEFLATED) as zf:
280
- for filename, detections in all_detections.items():
281
- filtered = [d for d in detections if d['score'] >= threshold]
282
- img_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
283
  img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
 
284
  for det in filtered:
285
  x1, y1, x2, y2 = map(int, det['bbox'])
286
  cv2.rectangle(img, (x1, y1), (x2, y2), (0,0,255), 3)
287
- # Save annotated image to memory
288
- is_tiff = filename.lower().endswith(('.tif', '.tiff'))
289
- out_name = f"{Path(filename).stem}.png"
290
  _, img_bytes = cv2.imencode('.png', img)
291
  zf.writestr(out_name, img_bytes.tobytes())
292
  memory_file.seek(0)
 
19
  import cv2
20
  import csv
21
  import numpy as np
22
+ import redis
23
+ import json
24
+ import shutil
25
 
26
  from yolo_utils import load_model, detect_image
27
 
 
38
  UPLOAD_FOLDER.mkdir(parents=True, exist_ok=True)
39
  RESULT_FOLDER.mkdir(parents=True, exist_ok=True)
40
 
41
+ # Redis client (localhost:6379, db=0, no password)
42
+ redis_client = redis.Redis(host='localhost', port=6379, db=0, decode_responses=True)
43
 
44
  @app.errorhandler(Exception)
45
  def handle_exception(e):
 
62
  _model = load_model(WEIGHTS_FILE)
63
  return _model
64
 
65
+ def cleanup_job(job_id):
66
+ # Remove files
67
+ upload_dir = os.path.join(app.config['UPLOAD_FOLDER'], job_id)
68
+ if os.path.exists(upload_dir):
69
+ shutil.rmtree(upload_dir)
70
+ # Remove Redis state
71
+ redis_client.delete(f"job:{job_id}")
72
+
73
+ @app.route('/cleanup/<job_id>', methods=['POST'])
74
+ def cleanup_job_endpoint(job_id):
75
+ cleanup_job(job_id)
76
+ return jsonify({'status': 'cleaned'})
77
+
78
+ def get_job_state(job_id):
79
+ data = redis_client.get(f"job:{job_id}")
80
+ return json.loads(data) if data else None
81
+
82
+ def set_job_state(job_id, state):
83
+ redis_client.set(f"job:{job_id}", json.dumps(state))
84
+
85
  all_detections = {}
86
 
87
  def process_image(args):
88
+ orig_name, unique_name, image_bytes = args
89
  model = get_model()
90
  detections = detect_image(model, image_bytes, conf=0.05)
91
+ # Save original image to uploads for later annotation (already saved)
92
+ return {'orig_name': orig_name, 'unique_name': unique_name, 'detections': detections}
 
 
 
 
93
 
94
  def async_process_images(job_id, file_data):
95
  try:
96
+ job_state = get_job_state(job_id)
97
+ job_state['status'] = 'running'
98
+ job_state['progress'] = 0
99
+ set_job_state(job_id, job_state)
100
  total = len(file_data)
101
  results = []
102
+ detections = {}
103
  with Pool(processes=min(cpu_count(), total)) as pool:
104
  for idx, result in enumerate(pool.imap(process_image, file_data)):
105
+ results.append({
106
+ 'filename': result['orig_name'],
107
+ 'num_eggs': sum(1 for d in result['detections'] if d.get('class') == 'egg'),
108
+ })
109
+ detections[result['orig_name']] = result['detections']
110
+ # Update progress
111
+ job_state['progress'] = int((idx + 1) / total * 100)
112
+ set_job_state(job_id, job_state)
113
+ job_state['status'] = 'success'
114
+ job_state['results'] = results
115
+ job_state['detections'] = detections
116
+ job_state['progress'] = 100
117
+ set_job_state(job_id, job_state)
118
  except Exception as e:
119
+ job_state = get_job_state(job_id) or {}
120
+ job_state['status'] = 'error'
121
+ job_state['error'] = str(e)
122
+ job_state['progress'] = 100
123
+ set_job_state(job_id, job_state)
124
 
125
  @app.route('/process', methods=['POST'])
126
  def process_images():
 
128
  files = request.files.getlist('files')
129
  if not files or files[0].filename == '':
130
  return jsonify({'error': 'No files uploaded'}), 400
 
 
131
  job_id = str(uuid.uuid4())
132
+ # Clean up any previous state for this job
133
+ cleanup_job(job_id)
134
+ filename_map, file_data = save_uploaded_files(files, job_id)
135
+ # Store initial job state in Redis
136
+ job_state = {
137
+ 'status': 'starting',
138
+ 'progress': 0,
139
+ 'results': [],
140
+ 'filename_map': filename_map,
141
+ 'detections': {},
142
+ }
143
+ set_job_state(job_id, job_state)
144
  thread = Thread(target=async_process_images, args=(job_id, file_data))
145
  thread.daemon = True
146
  thread.start()
 
150
  print(traceback.format_exc())
151
  return jsonify({'error': str(e)}), 500
152
 
153
+ def save_uploaded_files(files, job_id):
154
+ upload_dir = os.path.join(app.config['UPLOAD_FOLDER'], job_id)
155
+ if os.path.exists(upload_dir):
156
+ shutil.rmtree(upload_dir)
157
+ os.makedirs(upload_dir, exist_ok=True)
158
+ filename_map = {}
159
+ file_data = []
160
+ for f in files:
161
+ orig_name = secure_filename(f.filename)
162
+ ext = os.path.splitext(orig_name)[1]
163
+ unique_name = f"{uuid.uuid4().hex}{ext}"
164
+ file_path = os.path.join(upload_dir, unique_name)
165
+ f.save(file_path)
166
+ filename_map[orig_name] = unique_name
167
+ with open(file_path, 'rb') as imgf:
168
+ file_data.append((orig_name, unique_name, imgf.read()))
169
+ return filename_map, file_data
170
 
171
  @app.route('/progress/<job_id>')
172
  def get_progress(job_id):
173
+ job_state = get_job_state(job_id)
174
+ if not job_state:
175
  return jsonify({"status": "error", "error": "Job ID not found"}), 404
176
  # Add a mapping from filename to detections for frontend plotting
177
+ if 'detections' in job_state:
178
+ job_state['detections_by_filename'] = job_state['detections']
179
+ return jsonify(job_state)
 
180
 
181
  @app.route('/results/<job_id>/<path:filename>')
182
  def download_file(job_id, filename):
 
283
  def export_csv():
284
  try:
285
  data = request.json
286
+ job_id = data['jobId']
287
  threshold = float(data.get('confidence', 0.5))
288
+ job_state = get_job_state(job_id)
289
+ if not job_state:
290
+ return jsonify({'error': 'Job not found'}), 404
291
  rows = []
292
+ for orig_name, detections in job_state['detections'].items():
293
  count = sum(1 for d in detections if d['score'] >= threshold)
294
+ rows.append({'Filename': orig_name, 'EggsDetected': count})
295
  timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
296
  output = io.StringIO()
297
  writer = csv.DictWriter(output, fieldnames=['Filename', 'EggsDetected'])
 
314
  def export_images_post():
315
  try:
316
  data = request.json
317
+ job_id = data['jobId']
318
  threshold = float(data.get('confidence', 0.5))
319
+ job_state = get_job_state(job_id)
320
+ if not job_state:
321
+ return jsonify({'error': 'Job not found'}), 404
322
  memory_file = io.BytesIO()
323
  with zipfile.ZipFile(memory_file, 'w', zipfile.ZIP_DEFLATED) as zf:
324
+ for orig_name, detections in job_state['detections'].items():
325
+ unique_name = job_state['filename_map'][orig_name]
326
+ img_path = os.path.join(app.config['UPLOAD_FOLDER'], job_id, unique_name)
327
  img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
328
+ filtered = [d for d in detections if d['score'] >= threshold]
329
  for det in filtered:
330
  x1, y1, x2, y2 = map(int, det['bbox'])
331
  cv2.rectangle(img, (x1, y1), (x2, y2), (0,0,255), 3)
332
+ out_name = f"{Path(orig_name).stem}.png"
 
 
333
  _, img_bytes = cv2.imencode('.png', img)
334
  zf.writestr(out_name, img_bytes.tobytes())
335
  memory_file.seek(0)