Instructions to use mKartux/BanNano-model with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Keras
How to use mKartux/BanNano-model with Keras:
```python
# Available backend options are: "jax", "torch", "tensorflow".
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras
model = keras.saving.load_model("hf://mKartux/BanNano-model")
```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| Fruit Quality Classifier — FastAPI Backend v2 | |
| ============================================= | |
| CORRECCIONES v2: | |
| - Grad-CAM compatible con EfficientNetV2 anidado | |
| - Firma make_gradcam_heatmap simplificada | |
| - lifespan en lugar de on_event deprecado | |
| - Logs mejorados | |
| """ | |
| import os, io, base64, uuid, logging | |
| from typing import Optional | |
| from contextlib import asynccontextmanager | |
| import numpy as np | |
| from PIL import Image as PILImage | |
| import cv2 | |
| import tensorflow as tf | |
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from dotenv import load_dotenv | |
# Pull variables from a local .env file (if one is present) into os.environ so
# the os.getenv() lookups further down can see them.
load_dotenv()

# Root logging at INFO; all messages of this service go through "fruit-api".
logging.basicConfig(level="INFO")
logger = logging.getLogger(name="fruit-api")
| # --------------------------------------------------------------------------- | |
| # Constants | |
| # --------------------------------------------------------------------------- | |
# Class names in the index order of the model's output vector (predict() maps
# argmax -> CLASS_LABELS[idx]). The "Fresh_" prefix is what predict() uses to
# decide freshness, and feedback() uses these names as dataset folder names.
# NOTE(review): spellings such as "Bittergroud", "Capciscum" and "Okara" look
# like they mirror the training dataset's folder names; do not "fix" them
# without retraining the model - confirm against the training data.
CLASS_LABELS: list[str] = [
    # --- fresh produce (13) ---
    "Fresh_FreshApple",
    "Fresh_FreshBanana",
    "Fresh_FreshBellpepper",
    "Fresh_FreshBittergroud",
    "Fresh_FreshCapciscum",
    "Fresh_FreshCarrot",
    "Fresh_FreshCucumber",
    "Fresh_FreshMango",
    "Fresh_FreshOkara",
    "Fresh_FreshOrange",
    "Fresh_FreshPotato",
    "Fresh_FreshStrawberry",
    "Fresh_FreshTomato",
    # --- rotten produce (13) ---
    "Rotten_RottenApple",
    "Rotten_RottenBanana",
    "Rotten_RottenBellpepper",
    "Rotten_RottenBittergroud",
    "Rotten_RottenCapsicum",
    "Rotten_RottenCarrot",
    "Rotten_RottenCucumber",
    "Rotten_RottenMango",
    "Rotten_RottenOkra",
    "Rotten_RottenOrange",
    "Rotten_RottenPotato",
    "Rotten_RottenStrawberry",
    "Rotten_RottenTomato",
]

# Model input size in pixels. It is passed straight to cv2.resize, which reads
# it as (width, height), so keep it square or order it accordingly.
IMG_SIZE: tuple[int, int] = (224, 224)
| # --------------------------------------------------------------------------- | |
| # Model Loading | |
| # --------------------------------------------------------------------------- | |
# Module-level cache for the loaded classifier; stays None until a load succeeds.
_model: Optional[tf.keras.Model] = None


def _resolve_model_path() -> str:
    """Return a local path for the model file, downloading it from the Hub if needed.

    MODEL_PATH (default ``./model/fruit_classifier.keras``) is returned as-is
    when HF_MODEL_REPO is unset or when that path already exists. Otherwise
    ``fruit_classifier.keras`` is fetched from HF_MODEL_REPO (with HF_TOKEN as
    the token) and the downloaded file's path is returned.
    """
    local_path = os.getenv("MODEL_PATH", "./model/fruit_classifier.keras")
    repo_id = os.getenv("HF_MODEL_REPO", "")
    if not repo_id or os.path.exists(local_path):
        return local_path

    logger.info("Descargando modelo desde HuggingFace Hub: %s", repo_id)
    from huggingface_hub import hf_hub_download

    return hf_hub_download(
        repo_id=repo_id,
        filename="fruit_classifier.keras",
        token=os.getenv("HF_TOKEN"),
    )


def load_model() -> tf.keras.Model:
    """Return the classifier, loading and caching it on first use.

    The model is loaded with ``compile=False`` (inference only). A failed load
    leaves the cache empty, so the next call tries again.

    Raises:
        FileNotFoundError: if no model file exists at the resolved path.
    """
    global _model
    if _model is None:
        resolved = _resolve_model_path()
        if not os.path.exists(resolved):
            raise FileNotFoundError(
                f"Modelo no encontrado en {resolved}. "
                "Descarga el .keras exportado desde Colab y ponlo en backend/model/"
            )
        logger.info("Cargando modelo: %s", resolved)
        _model = tf.keras.models.load_model(resolved, compile=False)
        logger.info("Modelo cargado. Output shape: %s", _model.output_shape)
    return _model
| # --------------------------------------------------------------------------- | |
| # Grad-CAM — CORREGIDO para modelo anidado EfficientNetV2 | |
| # --------------------------------------------------------------------------- | |
def _find_efficientnet_submodel(model: tf.keras.Model):
    """Return the nested backbone model used for Grad-CAM.

    Lookup order:
      1. the first direct layer of ``model`` whose name contains "efficientnet"
         (case-insensitive);
      2. otherwise, the first direct layer that is itself a container (has a
         ``layers`` attribute) and holds at least one layer whose name contains
         "conv". This is the same name test the last-conv search relies on.

    The previous fallback returned ``model.layers[1]`` blindly. That raises
    IndexError on short models and can pick a non-model layer (e.g. a pooling
    layer) that fails later with an unclear AttributeError.

    Raises:
        ValueError: if no suitable backbone can be found.
    """
    top_layers = list(model.layers)

    for layer in top_layers:
        if "efficientnet" in layer.name.lower():
            return layer

    for layer in top_layers:
        inner = getattr(layer, "layers", None)
        if inner and any("conv" in sub.name.lower() for sub in inner):
            return layer

    raise ValueError(
        "No se encontró un sub-modelo backbone con capas Conv dentro del modelo."
    )
def _find_last_conv_layer(model: tf.keras.Model):
    """Locate the backbone and the name of its last conv-named layer.

    The match is purely by name: among ``base_model.layers`` in order, the last
    layer whose lower-cased name contains "conv" wins.

    Returns:
        ``(base_model, last_conv_layer_name)``.

    Raises:
        ValueError: if the backbone has no layer with "conv" in its name.
    """
    backbone = _find_efficientnet_submodel(model)
    conv_names = [
        layer.name for layer in backbone.layers if "conv" in layer.name.lower()
    ]
    if not conv_names:
        raise ValueError("No se encontró capa Conv en el backbone.")
    return backbone, conv_names[-1]
def make_gradcam_heatmap(
    img_array: np.ndarray,
    model: tf.keras.Model,
    pred_index: Optional[int] = None,
) -> np.ndarray:
    """Compute a Grad-CAM heatmap for a classifier wrapping a nested backbone.

    The outer model is replayed piece by piece inside a GradientTape:
    first the layers before the backbone, then a sub-model of the backbone
    exposing both its last conv layer and its final output, then the layers
    after the backbone. This exposes the conv activations even though they
    live inside a nested model.

    Args:
        img_array: input batch, already passed through
            ``efficientnet_v2.preprocess_input``; expected (1, 224, 224, 3).
            Only batch element 0 is used to build the map.
        model: outer classifier that contains the backbone as one of its layers.
        pred_index: class index to explain; defaults to the arg-max class.

    Returns:
        Heatmap at the conv layer's spatial resolution (normally 2-D), with
        values in [0, 1]. It is all zeros if no positive evidence is found.

    Raises:
        ValueError: if the backbone or its conv layer cannot be located.
        RuntimeError: if the gradient w.r.t. the conv activations is None.
    """
    base_model, last_conv_name = _find_last_conv_layer(model)
    grad_submodel = tf.keras.Model(
        inputs=base_model.input,
        outputs=[
            base_model.get_layer(last_conv_name).output,
            base_model.output,
        ],
    )

    top_layers = model.layers
    base_idx = top_layers.index(base_model)
    img_tensor = tf.cast(img_array, tf.float32)

    with tf.GradientTape() as tape:
        # Watch the input so every downstream op is recorded. Without this the
        # tape only tracks trainable variables, so with a frozen backbone
        # conv_outputs would not be connected and the gradient would be None.
        tape.watch(img_tensor)

        x = img_tensor
        # training=False keeps augmentation/dropout-style layers in inference
        # mode, matching what model.predict() used for the prediction.
        for layer in top_layers[:base_idx]:
            if "InputLayer" not in layer.__class__.__name__:
                x = layer(x, training=False)

        conv_outputs, x_top = grad_submodel(x, training=False)

        for layer in top_layers[base_idx + 1:]:
            x_top = layer(x_top, training=False)

        if pred_index is None:
            pred_index = int(tf.argmax(x_top[0]))
        class_channel = x_top[:, pred_index]

    grads = tape.gradient(class_channel, conv_outputs)
    if grads is None:
        raise RuntimeError("Grad-CAM: gradients are None.")

    # Per-channel weights: gradients averaged over batch and spatial axes.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    # Channel-weighted sum of activations: (H, W, C) @ (C, 1) -> (H, W, 1).
    heatmap = conv_outputs[0] @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)
    # Keep only regions that push the class score up.
    heatmap = tf.maximum(heatmap, 0)
    max_val = tf.reduce_max(heatmap)
    if max_val > 0:
        heatmap = heatmap / max_val
    return heatmap.numpy()
def superimpose_heatmap(
    original_img: np.ndarray,
    heatmap: np.ndarray,
    alpha: float = 0.5,
    colormap: int = cv2.COLORMAP_JET,
) -> np.ndarray:
    """Blend a colour-mapped heatmap over an image.

    Args:
        original_img: presumably a uint8 3-channel BGR image (as returned by
            preprocess_image). cv2.addWeighted requires it to match the uint8
            3-channel colour map in size and type.
        heatmap: 2-D array, nominally in [0, 1]. Values outside that range are
            clipped.
        alpha: weight of the heatmap in the blend (the image gets 1 - alpha).
        colormap: OpenCV colormap id passed to cv2.applyColorMap.

    Returns:
        The blended image, same size as ``original_img``.
    """
    height, width = original_img.shape[:2]
    # cv2.resize takes the target size as (width, height).
    heatmap_resized = cv2.resize(np.asarray(heatmap, dtype=np.float32), (width, height))
    # Clip before scaling: converting out-of-range floats to uint8 wraps/garbles
    # instead of saturating, which would corrupt the colour map.
    heatmap_uint8 = np.uint8(255 * np.clip(heatmap_resized, 0.0, 1.0))
    heatmap_color = cv2.applyColorMap(heatmap_uint8, colormap)
    return cv2.addWeighted(original_img, 1.0 - alpha, heatmap_color, alpha, 0)
| # --------------------------------------------------------------------------- | |
| # Image Helpers | |
| # --------------------------------------------------------------------------- | |
def img_to_base64(img: np.ndarray, fmt: str = ".jpg") -> str:
    """Encode an image array with OpenCV and return it as base64 text.

    Args:
        img: image array accepted by cv2.imencode.
        fmt: file extension that selects the encoder (default ".jpg").

    Raises:
        ValueError: if cv2.imencode reports failure.
    """
    ok, encoded = cv2.imencode(fmt, img)
    if ok:
        return base64.b64encode(encoded).decode("utf-8")
    raise ValueError("No se pudo codificar la imagen.")
def preprocess_image(file_bytes: bytes) -> tuple[np.ndarray, np.ndarray]:
    """Decode uploaded image bytes and build the model input.

    Always uses ``efficientnet_v2.preprocess_input``, never a manual /255.

    Returns:
        ``(original_bgr, model_input)``:
        - ``original_bgr`` is the decoded BGR image, downscaled so its longest
          side is at most 1024 px.
        - ``model_input`` is that image converted to RGB, resized to IMG_SIZE,
          preprocessed, with a leading batch axis of 1.

    Raises:
        ValueError: if the bytes are empty or cannot be decoded as an image.
    """
    if not file_bytes:
        # cv2.imdecode raises cv2.error (not ValueError) on an empty buffer, so
        # callers that map ValueError to a client error would return a 500.
        raise ValueError("Imagen vacía — el archivo no contiene datos.")

    nparr = np.frombuffer(file_bytes, np.uint8)
    original_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if original_bgr is None:
        raise ValueError("Imagen inválida — no se pudo decodificar.")

    original_bgr = _resize_if_needed(original_bgr)
    rgb = cv2.cvtColor(original_bgr, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(rgb, IMG_SIZE)
    model_input = tf.keras.applications.efficientnet_v2.preprocess_input(
        resized.astype(np.float32)
    )
    model_input = np.expand_dims(model_input, axis=0)
    return original_bgr, model_input
| def _resize_if_needed(img: np.ndarray, max_dim: int = 1024) -> np.ndarray: | |
| h, w = img.shape[:2] | |
| if max(h, w) > max_dim: | |
| scale = max_dim / max(h, w) | |
| new_w, new_h = int(w * scale), int(h * scale) | |
| return cv2.resize(img, (new_w, new_h)) | |
| return img | |
| # --------------------------------------------------------------------------- | |
| # Lifespan (reemplaza @app.on_event deprecado) | |
| # --------------------------------------------------------------------------- | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler: try to load the model once at startup.

    Wrapped in ``asynccontextmanager`` because FastAPI/Starlette expect a
    lifespan context-manager factory. Bare async-generator lifespans are
    deprecated in Starlette.

    A loading failure is logged, not raised, so the API still starts. The
    model cache stays empty, and later load_model() calls will try again.
    """
    try:
        load_model()
        logger.info("Modelo pre-cargado en startup.")
    except Exception as e:
        logger.warning("Modelo no disponible en startup: %s", e)
    yield
| # --------------------------------------------------------------------------- | |
| # FastAPI App | |
| # --------------------------------------------------------------------------- | |
app = FastAPI(
    title="Fruit Quality Classifier API v2",
    description="Clasificador de frutas con Grad-CAM para detección de descomposición",
    version="2.0.0",
    lifespan=lifespan,
)

# Allowed CORS origins as a comma-separated list in CORS_ORIGINS. It defaults
# to "*", which is the value that was previously hard-coded.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# NOTE(review): the Starlette CORS docs state that wildcard origins are not
# meant to be combined with allow_credentials=True. This API reads no cookies
# or auth headers, so consider setting CORS_ORIGINS explicitly or disabling
# credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
| # --------------------------------------------------------------------------- | |
| # Endpoints | |
| # --------------------------------------------------------------------------- | |
async def health():
    """Report whether the classifier can be loaded.

    Returns:
        On success, a dict (HTTP 200) with status "ok", the number of output
        classes and the model name. If anything in the check raises, a 503
        JSONResponse carrying the error text.
    """
    # NOTE(review): no FastAPI route decorator is visible on this handler -
    # confirm it is registered with the app.
    try:
        loaded = load_model()
        payload = {
            "status": "ok",
            "model_loaded": True,
            "num_classes": loaded.output_shape[-1],
            "model_name": loaded.name,
        }
    except Exception as exc:
        return JSONResponse(
            status_code=503,
            content={"status": "error", "model_loaded": False, "detail": str(exc)},
        )
    return payload
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image; add a Grad-CAM overlay for non-fresh classes.

    The JSON response contains ``class_name``, ``confidence``,
    ``all_probabilities``, ``image_base64`` (the possibly downscaled original,
    JPEG-encoded) and ``is_fresh``. It also contains ``heatmap_base64`` when the
    class is not "Fresh_*" and Grad-CAM succeeds. Grad-CAM is best effort: a
    failure is logged and the prediction is still returned.

    Raises:
        HTTPException 400: non-image content type, or empty/undecodable bytes.
        HTTPException 503: the model cannot be loaded.
        HTTPException 500: the model's output size does not match CLASS_LABELS.
    """
    # NOTE(review): no FastAPI route decorator is visible on this handler -
    # confirm it is registered. The TF work below runs synchronously inside
    # this async handler and blocks the event loop while it runs.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(400, "El archivo debe ser una imagen (JPEG, PNG, etc.).")
    contents = await file.read()

    try:
        original_bgr, model_input = preprocess_image(contents)
    except ValueError as e:
        raise HTTPException(400, str(e)) from e

    try:
        model = load_model()
    except Exception as e:
        raise HTTPException(503, f"Modelo no disponible: {e}") from e

    predictions = model.predict(model_input, verbose=0)
    preds_flat = np.asarray(predictions[0])
    # A model/label mismatch would otherwise either IndexError on the lookup or,
    # worse, silently return the wrong label names.
    if preds_flat.shape != (len(CLASS_LABELS),):
        logger.error(
            "Salida del modelo incompatible: shape=%s, clases=%d",
            preds_flat.shape, len(CLASS_LABELS),
        )
        raise HTTPException(
            500,
            f"El modelo devuelve {preds_flat.size} valores pero hay "
            f"{len(CLASS_LABELS)} clases configuradas.",
        )

    pred_class_idx = int(np.argmax(preds_flat))
    confidence = float(preds_flat[pred_class_idx])
    class_name = CLASS_LABELS[pred_class_idx]
    is_fresh = class_name.startswith("Fresh_")

    response: dict = {
        "class_name": class_name,
        "confidence": round(confidence, 6),
        "all_probabilities": [round(float(p), 6) for p in preds_flat],
        "image_base64": img_to_base64(original_bgr),
        "is_fresh": is_fresh,
    }

    if not is_fresh:
        # Best effort: a Grad-CAM failure must not fail the prediction.
        try:
            heatmap = make_gradcam_heatmap(model_input, model, pred_class_idx)
            superimposed = superimpose_heatmap(original_bgr, heatmap)
            response["heatmap_base64"] = img_to_base64(superimposed)
            logger.info("Grad-CAM generado para clase: %s", class_name)
        except Exception as e:
            logger.warning("Grad-CAM falló para %s: %s", class_name, e)

    logger.info(
        "Predicción: %s (%.2f%%, fresh=%s)",
        class_name, confidence * 100, is_fresh,
    )
    return JSONResponse(content=response)
async def feedback(
    file: UploadFile = File(...),
    correct_label: str = Form(...),
):
    """Store a user-corrected sample in the Hugging Face dataset repo.

    The upload is written to ``data/<correct_label>/<uuid4>.<ext>`` in the
    dataset repo named by HF_DATASET_REPO, authenticated with HF_TOKEN. ``ext``
    is taken from the client filename when it is jpg/jpeg/png; otherwise it is
    "jpg".

    Returns:
        ``{"status": "ok", "image_id": ..., "label": correct_label}``.

    Raises:
        HTTPException 400: unknown label, non-image content type, or empty /
            undecodable file.
        HTTPException 503: HF_TOKEN or HF_DATASET_REPO is not configured.
        HTTPException 500: huggingface_hub is missing or the upload fails.
    """
    # NOTE(review): no FastAPI route decorator is visible on this handler -
    # confirm it is registered with the app.
    if correct_label not in CLASS_LABELS:
        raise HTTPException(
            400,
            f"Label inválido '{correct_label}'. Debe ser una de las "
            f"{len(CLASS_LABELS)} clases conocidas.",
        )
    # Same content-type gate as predict(): only images go into the dataset.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(400, "El archivo debe ser una imagen (JPEG, PNG, etc.).")

    hf_token = os.getenv("HF_TOKEN")
    repo_id = os.getenv("HF_DATASET_REPO")
    if not hf_token or not repo_id:
        raise HTTPException(
            503, "Feedback no configurado. Define HF_TOKEN y HF_DATASET_REPO."
        )

    contents = await file.read()
    if not contents:
        raise HTTPException(400, "El archivo está vacío.")
    # The content type is client-supplied; make sure the bytes really decode
    # as an image before pushing them into the training dataset.
    if cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR) is None:
        raise HTTPException(400, "Imagen inválida — no se pudo decodificar.")

    suffix = ""
    if file.filename and "." in file.filename:
        suffix = file.filename.rsplit(".", 1)[-1].lower()
    ext = suffix if suffix in ("jpg", "jpeg", "png") else "jpg"

    image_id = str(uuid.uuid4())
    path_in_repo = f"data/{correct_label}/{image_id}.{ext}"

    try:
        from huggingface_hub import HfApi
    except ImportError as e:
        raise HTTPException(
            500, "huggingface_hub no instalado. Ejecuta: pip install huggingface_hub"
        ) from e
    from fastapi.concurrency import run_in_threadpool

    try:
        api = HfApi(token=hf_token)
        # upload_file is blocking network I/O. Running it in the threadpool
        # keeps the event loop free to serve other requests meanwhile.
        await run_in_threadpool(
            api.upload_file,
            path_or_fileobj=contents,
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            repo_type="dataset",
        )
    except Exception as e:
        logger.exception("Error subiendo feedback: %s", e)
        raise HTTPException(500, f"Fallo al subir feedback: {str(e)}") from e

    logger.info("Feedback subido: %s → %s", image_id, correct_label)
    return {"status": "ok", "image_id": image_id, "label": correct_label}