Spaces:
Sleeping
Sleeping
Kesheratmex
committed on
Commit
·
cbd697d
1
Parent(s):
96316b4
**Make infer_media configurable, return counts and support scaling**
Browse files- Added optional parameters (`conf`, `iou`, `out_res`, `preset`) to `infer_media` for flexible inference settings.
- Implemented validation for required `media_path`.
- Introduced resolution mapping and optional resizing of output frames/images.
- Added per-class counting for both video and image processing, returning a dictionary with paths and class counts.
- Improved video FPS handling with fallback and NaN checks.
- Updated UI: added hidden JSON components to expose the new dictionary results.
- Adjusted drawing code and file handling to work with the new return structure.
app.py
CHANGED
|
@@ -14,29 +14,53 @@ model = YOLO("best2.pt") # carga el modelo UNA sola vez
|
|
| 14 |
# ββββββββββββββββββββββββββββ
|
| 15 |
# Funciones de Inferencia
|
| 16 |
# ββββββββββββββββββββββββββββ
|
| 17 |
-
def infer_media(media_path):
|
| 18 |
"""
|
| 19 |
-
Procesa un fichero de vΓdeo o imagen
|
| 20 |
-
|
| 21 |
-
|
|
|
|
| 22 |
"""
|
|
|
|
|
|
|
|
|
|
| 23 |
ext = os.path.splitext(media_path)[1].lower()
|
| 24 |
tmpdir = tempfile.mkdtemp()
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
# β VΓdeo βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 27 |
if ext in [".mp4", ".mov", ".avi", ".mkv"]:
|
| 28 |
in_vid = os.path.join(tmpdir, "in.mp4")
|
| 29 |
out_vid = os.path.join(tmpdir, "out.mp4")
|
| 30 |
shutil.copy(media_path, in_vid)
|
| 31 |
|
| 32 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
writer = None
|
| 34 |
-
|
| 35 |
|
| 36 |
-
# Streaming de frames con anotaciones
|
| 37 |
-
results = model.predict(source=in_vid, conf=
|
| 38 |
for r in results:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
annotated = r.plot() # frame anotado
|
|
|
|
|
|
|
| 40 |
|
| 41 |
if writer is None:
|
| 42 |
h, w = annotated.shape[:2]
|
|
@@ -47,30 +71,40 @@ def infer_media(media_path):
|
|
| 47 |
|
| 48 |
if writer:
|
| 49 |
writer.release()
|
|
|
|
|
|
|
| 50 |
|
| 51 |
-
return out_vid
|
| 52 |
|
| 53 |
# β Imagen ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 54 |
elif ext in [".jpg", ".jpeg", ".png", ".bmp"]:
|
| 55 |
img = cv2.imread(media_path)
|
| 56 |
-
results = model.predict(source=media_path, conf=
|
| 57 |
|
| 58 |
-
|
|
|
|
| 59 |
for box in results[0].boxes:
|
| 60 |
x1, y1, x2, y2 = map(int, box.xyxy[0])
|
| 61 |
cls_id = int(box.cls[0])
|
| 62 |
label = model.names[cls_id]
|
|
|
|
| 63 |
# rectΓ‘ngulo
|
| 64 |
-
cv2.rectangle(img, (x1, y1), (x2, y2), (0,255,0), 2)
|
| 65 |
# texto
|
| 66 |
cv2.putText(img, label, (x1, y1 - 10),
|
| 67 |
-
cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0,255,0), 2)
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
-
|
|
|
|
|
|
|
| 70 |
|
| 71 |
else:
|
| 72 |
raise ValueError(f"Formato no soportado: {ext}")
|
| 73 |
|
|
|
|
| 74 |
def show_classes():
|
| 75 |
"""Devuelve las clases que el modelo conoce."""
|
| 76 |
names = model.names
|
|
@@ -98,10 +132,19 @@ with gr.Blocks(title="Kesherat Β· InspecciΓ³n de palas eΓ³licas") as demo:
|
|
| 98 |
output_video = gr.Video(label="VΓdeo anotado")
|
| 99 |
output_image = gr.Image(label="Imagen anotada")
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
btn_detect = gr.Button("Detectar defectos")
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
btn_detect.click(fn=infer_media, inputs=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
btn_classes = gr.Button("Mostrar clases del modelo")
|
| 107 |
txt_classes = gr.Textbox(label="Clases cargadas", interactive=False)
|
|
|
|
| 14 |
# ββββββββββββββββββββββββββββ
|
| 15 |
# Funciones de Inferencia
|
| 16 |
# ββββββββββββββββββββββββββββ
|
| 17 |
+
def infer_media(media_path, conf=0.25, iou=0.45, out_res="720p", preset="default"):
|
| 18 |
"""
|
| 19 |
+
Procesa un fichero de vΓdeo o imagen con parΓ‘metros configurables.
|
| 20 |
+
Retornos:
|
| 21 |
+
- VΓdeo: {"video": out_vid_path, "classes": {label: count, ...}}
|
| 22 |
+
- Imagen: {"path": out_img_path, "classes": {label: count, ...}}
|
| 23 |
"""
|
| 24 |
+
if not media_path:
|
| 25 |
+
raise ValueError("media_path es requerido")
|
| 26 |
+
|
| 27 |
ext = os.path.splitext(media_path)[1].lower()
|
| 28 |
tmpdir = tempfile.mkdtemp()
|
| 29 |
|
| 30 |
+
# ResoluciΓ³n objetivo
|
| 31 |
+
res_map = {"360p": (640, 360), "480p": (854, 480), "720p": (1280, 720)}
|
| 32 |
+
target_size = res_map.get(out_res)
|
| 33 |
+
|
| 34 |
# β VΓdeo βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 35 |
if ext in [".mp4", ".mov", ".avi", ".mkv"]:
|
| 36 |
in_vid = os.path.join(tmpdir, "in.mp4")
|
| 37 |
out_vid = os.path.join(tmpdir, "out.mp4")
|
| 38 |
shutil.copy(media_path, in_vid)
|
| 39 |
|
| 40 |
+
# FPS del vΓdeo (opcional: tomar real si existe)
|
| 41 |
+
cap = cv2.VideoCapture(in_vid)
|
| 42 |
+
fps = cap.get(cv2.CAP_PROP_FPS) or 30
|
| 43 |
+
try:
|
| 44 |
+
fps = float(fps)
|
| 45 |
+
if fps <= 0 or fps != fps: # NaN check
|
| 46 |
+
fps = 30
|
| 47 |
+
except Exception:
|
| 48 |
+
fps = 30
|
| 49 |
+
|
| 50 |
writer = None
|
| 51 |
+
counts = {}
|
| 52 |
|
| 53 |
+
# Streaming de frames con anotaciones y conteo por clase
|
| 54 |
+
results = model.predict(source=in_vid, conf=conf, iou=iou, stream=True)
|
| 55 |
for r in results:
|
| 56 |
+
# acumular conteos
|
| 57 |
+
for b in r.boxes:
|
| 58 |
+
label = model.names[int(b.cls[0])]
|
| 59 |
+
counts[label] = counts.get(label, 0) + 1
|
| 60 |
+
|
| 61 |
annotated = r.plot() # frame anotado
|
| 62 |
+
if target_size:
|
| 63 |
+
annotated = cv2.resize(annotated, target_size)
|
| 64 |
|
| 65 |
if writer is None:
|
| 66 |
h, w = annotated.shape[:2]
|
|
|
|
| 71 |
|
| 72 |
if writer:
|
| 73 |
writer.release()
|
| 74 |
+
if cap:
|
| 75 |
+
cap.release()
|
| 76 |
|
| 77 |
+
return {"video": out_vid, "classes": counts}
|
| 78 |
|
| 79 |
# β Imagen ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 80 |
elif ext in [".jpg", ".jpeg", ".png", ".bmp"]:
|
| 81 |
img = cv2.imread(media_path)
|
| 82 |
+
results = model.predict(source=media_path, conf=conf, iou=iou, save=False)
|
| 83 |
|
| 84 |
+
counts = {}
|
| 85 |
+
# Dibujamos cajas manualmente y contamos
|
| 86 |
for box in results[0].boxes:
|
| 87 |
x1, y1, x2, y2 = map(int, box.xyxy[0])
|
| 88 |
cls_id = int(box.cls[0])
|
| 89 |
label = model.names[cls_id]
|
| 90 |
+
counts[label] = counts.get(label, 0) + 1
|
| 91 |
# rectΓ‘ngulo
|
| 92 |
+
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
| 93 |
# texto
|
| 94 |
cv2.putText(img, label, (x1, y1 - 10),
|
| 95 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
|
| 96 |
+
|
| 97 |
+
if target_size:
|
| 98 |
+
img = cv2.resize(img, target_size)
|
| 99 |
|
| 100 |
+
out_path = os.path.join(tmpdir, "annotated.png")
|
| 101 |
+
cv2.imwrite(out_path, img)
|
| 102 |
+
return {"path": out_path, "classes": counts}
|
| 103 |
|
| 104 |
else:
|
| 105 |
raise ValueError(f"Formato no soportado: {ext}")
|
| 106 |
|
| 107 |
+
|
| 108 |
def show_classes():
|
| 109 |
"""Devuelve las clases que el modelo conoce."""
|
| 110 |
names = model.names
|
|
|
|
| 132 |
output_video = gr.Video(label="VΓdeo anotado")
|
| 133 |
output_image = gr.Image(label="Imagen anotada")
|
| 134 |
|
| 135 |
+
# Componentes JSON ocultos para soportar API devolviendo dict y encadenar a la UI
|
| 136 |
+
json_video = gr.JSON(visible=False)
|
| 137 |
+
json_image = gr.JSON(visible=False)
|
| 138 |
+
|
| 139 |
btn_detect = gr.Button("Detectar defectos")
|
| 140 |
+
|
| 141 |
+
# Endpoint API para vΓdeo: devuelve dict {video, classes}. UI: extrae solo el vΓdeo.
|
| 142 |
+
ev_video = btn_detect.click(fn=infer_media, inputs=video_input, outputs=json_video, api_name="/infer_media")
|
| 143 |
+
ev_video.then(lambda d: (d.get("video") if isinstance(d, dict) else d), inputs=json_video, outputs=output_video)
|
| 144 |
+
|
| 145 |
+
# Endpoint API para imagen: devuelve dict {path, classes}. UI: extrae solo la imagen.
|
| 146 |
+
ev_image = btn_detect.click(fn=infer_media, inputs=image_input, outputs=json_image, api_name="/infer_media_1")
|
| 147 |
+
ev_image.then(lambda d: (d.get("path") if isinstance(d, dict) else d), inputs=json_image, outputs=output_image)
|
| 148 |
|
| 149 |
btn_classes = gr.Button("Mostrar clases del modelo")
|
| 150 |
txt_classes = gr.Textbox(label="Clases cargadas", interactive=False)
|