Kesheratmex committed on
Commit
cbd697d
·
1 Parent(s): 96316b4

**Make infer_media configurable, return counts and support scaling**

Browse files

- Added optional parameters (`conf`, `iou`, `out_res`, `preset`) to `infer_media` for flexible inference settings.
- Implemented validation for required `media_path`.
- Introduced resolution mapping and optional resizing of output frames/images.
- Added per‑class counting for both video and image processing, returning a dictionary with paths and class counts.
- Improved video FPS handling with fallback and NaN checks.
- Updated UI: added hidden JSON components to expose the new dictionary results.
- Adjusted drawing code and file handling to work with the new return structure.

Files changed (1) hide show
  1. app.py +60 -17
app.py CHANGED
@@ -14,29 +14,53 @@ model = YOLO("best2.pt") # carga el modelo UNA sola vez
14
  # ────────────────────────────
15
  # Funciones de Inferencia
16
  # ────────────────────────────
17
- def infer_media(media_path):
18
  """
19
- Procesa un fichero de vΓ­deo o imagen:
20
- - Si es vΓ­deo, lo anota frame a frame y devuelve un MP4.
21
- - Si es imagen, dibuja las cajas sobre la imagen y devuelve un array BGR.
 
22
  """
 
 
 
23
  ext = os.path.splitext(media_path)[1].lower()
24
  tmpdir = tempfile.mkdtemp()
25
 
 
 
 
 
26
  # ─ VΓ­deo ───────────────────────────────────────────────────────
27
  if ext in [".mp4", ".mov", ".avi", ".mkv"]:
28
  in_vid = os.path.join(tmpdir, "in.mp4")
29
  out_vid = os.path.join(tmpdir, "out.mp4")
30
  shutil.copy(media_path, in_vid)
31
 
32
- # Preparamos writer de vΓ­deo
 
 
 
 
 
 
 
 
 
33
  writer = None
34
- fps = 30 # ajΓΊstalo si tu vΓ­deo tiene otro FPS
35
 
36
- # Streaming de frames con anotaciones
37
- results = model.predict(source=in_vid, conf=0.25, iou=0.45, stream=True)
38
  for r in results:
 
 
 
 
 
39
  annotated = r.plot() # frame anotado
 
 
40
 
41
  if writer is None:
42
  h, w = annotated.shape[:2]
@@ -47,30 +71,40 @@ def infer_media(media_path):
47
 
48
  if writer:
49
  writer.release()
 
 
50
 
51
- return out_vid
52
 
53
  # ─ Imagen ──────────────────────────────────────────────────────
54
  elif ext in [".jpg", ".jpeg", ".png", ".bmp"]:
55
  img = cv2.imread(media_path)
56
- results = model.predict(source=media_path, conf=0.25, iou=0.45, save=False)
57
 
58
- # Dibujamos cajas manualmente
 
59
  for box in results[0].boxes:
60
  x1, y1, x2, y2 = map(int, box.xyxy[0])
61
  cls_id = int(box.cls[0])
62
  label = model.names[cls_id]
 
63
  # rectΓ‘ngulo
64
- cv2.rectangle(img, (x1, y1), (x2, y2), (0,255,0), 2)
65
  # texto
66
  cv2.putText(img, label, (x1, y1 - 10),
67
- cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0,255,0), 2)
 
 
 
68
 
69
- return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
 
 
70
 
71
  else:
72
  raise ValueError(f"Formato no soportado: {ext}")
73
 
 
74
  def show_classes():
75
  """Devuelve las clases que el modelo conoce."""
76
  names = model.names
@@ -98,10 +132,19 @@ with gr.Blocks(title="Kesherat · Inspección de palas eólicas") as demo:
98
  output_video = gr.Video(label="VΓ­deo anotado")
99
  output_image = gr.Image(label="Imagen anotada")
100
 
 
 
 
 
101
  btn_detect = gr.Button("Detectar defectos")
102
- # Conectamos ambos inputs al mismo infer_media, con salidas condicionadas
103
- btn_detect.click(fn=infer_media, inputs=video_input, outputs=output_video)
104
- btn_detect.click(fn=infer_media, inputs=image_input, outputs=output_image)
 
 
 
 
 
105
 
106
  btn_classes = gr.Button("Mostrar clases del modelo")
107
  txt_classes = gr.Textbox(label="Clases cargadas", interactive=False)
 
14
  # ────────────────────────────
15
  # Funciones de Inferencia
16
  # ────────────────────────────
17
+ def infer_media(media_path, conf=0.25, iou=0.45, out_res="720p", preset="default"):
18
  """
19
+ Procesa un fichero de vΓ­deo o imagen con parΓ‘metros configurables.
20
+ Retornos:
21
+ - VΓ­deo: {"video": out_vid_path, "classes": {label: count, ...}}
22
+ - Imagen: {"path": out_img_path, "classes": {label: count, ...}}
23
  """
24
+ if not media_path:
25
+ raise ValueError("media_path es requerido")
26
+
27
  ext = os.path.splitext(media_path)[1].lower()
28
  tmpdir = tempfile.mkdtemp()
29
 
30
+ # ResoluciΓ³n objetivo
31
+ res_map = {"360p": (640, 360), "480p": (854, 480), "720p": (1280, 720)}
32
+ target_size = res_map.get(out_res)
33
+
34
  # ─ VΓ­deo ───────────────────────────────────────────────────────
35
  if ext in [".mp4", ".mov", ".avi", ".mkv"]:
36
  in_vid = os.path.join(tmpdir, "in.mp4")
37
  out_vid = os.path.join(tmpdir, "out.mp4")
38
  shutil.copy(media_path, in_vid)
39
 
40
+ # FPS del vΓ­deo (opcional: tomar real si existe)
41
+ cap = cv2.VideoCapture(in_vid)
42
+ fps = cap.get(cv2.CAP_PROP_FPS) or 30
43
+ try:
44
+ fps = float(fps)
45
+ if fps <= 0 or fps != fps: # NaN check
46
+ fps = 30
47
+ except Exception:
48
+ fps = 30
49
+
50
  writer = None
51
+ counts = {}
52
 
53
+ # Streaming de frames con anotaciones y conteo por clase
54
+ results = model.predict(source=in_vid, conf=conf, iou=iou, stream=True)
55
  for r in results:
56
+ # acumular conteos
57
+ for b in r.boxes:
58
+ label = model.names[int(b.cls[0])]
59
+ counts[label] = counts.get(label, 0) + 1
60
+
61
  annotated = r.plot() # frame anotado
62
+ if target_size:
63
+ annotated = cv2.resize(annotated, target_size)
64
 
65
  if writer is None:
66
  h, w = annotated.shape[:2]
 
71
 
72
  if writer:
73
  writer.release()
74
+ if cap:
75
+ cap.release()
76
 
77
+ return {"video": out_vid, "classes": counts}
78
 
79
  # ─ Imagen ──────────────────────────────────────────────────────
80
  elif ext in [".jpg", ".jpeg", ".png", ".bmp"]:
81
  img = cv2.imread(media_path)
82
+ results = model.predict(source=media_path, conf=conf, iou=iou, save=False)
83
 
84
+ counts = {}
85
+ # Dibujamos cajas manualmente y contamos
86
  for box in results[0].boxes:
87
  x1, y1, x2, y2 = map(int, box.xyxy[0])
88
  cls_id = int(box.cls[0])
89
  label = model.names[cls_id]
90
+ counts[label] = counts.get(label, 0) + 1
91
  # rectΓ‘ngulo
92
+ cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
93
  # texto
94
  cv2.putText(img, label, (x1, y1 - 10),
95
+ cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
96
+
97
+ if target_size:
98
+ img = cv2.resize(img, target_size)
99
 
100
+ out_path = os.path.join(tmpdir, "annotated.png")
101
+ cv2.imwrite(out_path, img)
102
+ return {"path": out_path, "classes": counts}
103
 
104
  else:
105
  raise ValueError(f"Formato no soportado: {ext}")
106
 
107
+
108
  def show_classes():
109
  """Devuelve las clases que el modelo conoce."""
110
  names = model.names
 
132
  output_video = gr.Video(label="VΓ­deo anotado")
133
  output_image = gr.Image(label="Imagen anotada")
134
 
135
+ # Componentes JSON ocultos para soportar API devolviendo dict y encadenar a la UI
136
+ json_video = gr.JSON(visible=False)
137
+ json_image = gr.JSON(visible=False)
138
+
139
  btn_detect = gr.Button("Detectar defectos")
140
+
141
+ # Endpoint API para vΓ­deo: devuelve dict {video, classes}. UI: extrae solo el vΓ­deo.
142
+ ev_video = btn_detect.click(fn=infer_media, inputs=video_input, outputs=json_video, api_name="/infer_media")
143
+ ev_video.then(lambda d: (d.get("video") if isinstance(d, dict) else d), inputs=json_video, outputs=output_video)
144
+
145
+ # Endpoint API para imagen: devuelve dict {path, classes}. UI: extrae solo la imagen.
146
+ ev_image = btn_detect.click(fn=infer_media, inputs=image_input, outputs=json_image, api_name="/infer_media_1")
147
+ ev_image.then(lambda d: (d.get("path") if isinstance(d, dict) else d), inputs=json_image, outputs=output_image)
148
 
149
  btn_classes = gr.Button("Mostrar clases del modelo")
150
  txt_classes = gr.Textbox(label="Clases cargadas", interactive=False)