import base64
import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from transformers import (
    ViTFeatureExtractor,
    ViTForImageClassification,
    ViTImageProcessor,
)
# Hugging Face Hub checkpoint used for classification. The repo name suggests
# a ViT-base fine-tuned on HAM10000 with 32px patches — TODO confirm against
# the model card.
MODEL_NAME = "ahishamm/vit-base-HAM-10000-sharpened-patch-32"

# ViTImageProcessor is the replacement for the deprecated ViTFeatureExtractor
# (which emits a FutureWarning and is slated for removal in Transformers v5);
# it reads the same preprocessing config from the checkpoint. The variable is
# still called `feature_extractor` because analizar_lesion_vit_web calls it
# under that name.
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
model = ViTForImageClassification.from_pretrained(MODEL_NAME)
# The model is only used for inference, so put it in evaluation mode.
model.eval()
# Spanish display labels, looked up by the argmax index of the model's output
# probabilities. The order presumably follows the usual HAM10000 label order
# (akiec, bcc, bkl, df, mel, nv, vasc).
# NOTE(review): confirm it matches model.config.id2label for MODEL_NAME.
CLASSES = [
    "Queratosis actínica / Bowen",  # 0 - actinic keratosis / Bowen's disease
    "Carcinoma células basales",    # 1 - basal cell carcinoma
    "Lesión queratósica benigna",   # 2 - benign keratosis-like lesion
    "Dermatofibroma",               # 3 - dermatofibroma
    "Melanoma maligno",             # 4 - malignant melanoma
    "Nevus melanocítico",           # 5 - melanocytic nevus
    "Lesión vascular",              # 6 - vascular lesion
]
# Per-class triage metadata, keyed by the same integer index as CLASSES.
#   level  - Spanish risk-tier label
#   color  - hex colour used for that class's bar in the probability chart
#   weight - factor multiplied by the class probability when summing the
#            overall cancer_risk_score
# Each class gets its own fresh inner dict (no shared objects between keys).
RISK_LEVELS = {
    class_idx: {'level': level, 'color': color, 'weight': weight}
    for class_idx, (level, color, weight) in enumerate((
        ('Moderado', '#ffaa00', 0.6),  # 0 actinic keratosis / Bowen
        ('Alto', '#ff4444', 0.8),      # 1 basal cell carcinoma
        ('Bajo', '#44ff44', 0.1),      # 2 benign keratosis-like lesion
        ('Bajo', '#44ff44', 0.1),      # 3 dermatofibroma
        ('Crítico', '#cc0000', 1.0),   # 4 malignant melanoma
        ('Bajo', '#44ff44', 0.1),      # 5 melanocytic nevus
        ('Bajo', '#44ff44', 0.1),      # 6 vascular lesion
    ))
}
def analizar_lesion_vit_web(img):
inputs = feature_extractor(img, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
probs = outputs.logits.softmax(dim=-1).cpu().numpy()[0]
pred_idx = int(np.argmax(probs))
pred_clase = CLASSES[pred_idx]
confianza = probs[pred_idx]
cancer_risk_score = sum(probs[i] * RISK_LEVELS[i]['weight'] for i in range(7))
melanoma_risk = probs[4]
bcc_risk = probs[1]
precancer_risk = probs[0]
benign_total = sum(probs[i] for i in [2,3,5,6])
colors_bars = [RISK_LEVELS[i]['color'] for i in range(7)]
fig, ax = plt.subplots(figsize=(8,3))
ax.bar(CLASSES, probs*100, color=colors_bars)
ax.set_title("Probabilidad por tipo de lesión")
ax.set_ylabel("Probabilidad (%)")
ax.set_xticklabels(CLASSES, rotation=45, ha='right')
ax.grid(axis='y', alpha=0.2)
plt.tight_layout()
buf = io.BytesIO()
plt.savefig(buf, format="png")
plt.close(fig)
buf.seek(0)
img_bytes = buf.getvalue()
img_b64 = base64.b64encode(img_bytes).decode("utf-8")
html_chart = f''
urgency = ""
recommendation = ""
timeframe = ""
if cancer_risk_score > 0.6:
urgency = "🚨 CRÍTICO"
recommendation = "Derivación INMEDIATA a oncología dermatológica"
timeframe = "En 24-48 horas máximo"
elif cancer_risk_score > 0.4:
urgency = "⚠️ ALTO RIESGO"
recommendation = "Consulta urgente con dermatólogo especialista"
timeframe = "En 1 semana máximo"
elif cancer_risk_score > 0.2:
urgency = "📋 RIESGO MODERADO"
recommendation = "Evaluación dermatológica programada"
timeframe = "En 2-4 semanas"
else:
urgency = "✅ BAJO RIESGO"
recommendation = "Seguimiento de rutina"
timeframe = "En 3-6 meses"
informe = f"""
{bar_visual}