# app.py — DeepSeek-OCR + DeepSeek-R1 Medical Mini (remote HF or local GGUF) — Gradio 5
import os, tempfile, traceback
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
import spaces
from huggingface_hub import hf_hub_download, InferenceClient
from llama_cpp import Llama
# ===============================================================
# LLM (CHAT) configuration — DeepSeek-R1 Medical Mini
# - Remote (HF Inference): R1_REMOTE=1 and (optionally) R1_MODEL_ID, HF_TOKEN
# - Local GGUF (CPU/Zero): R1_REMOTE=0 and GGUF_REPO / GGUF_FILE
# ===============================================================
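# Example invocations (assumed; any of these env vars may be omitted):
#   Remote via HF Inference:  R1_REMOTE=1 HF_TOKEN=<token> python app.py
#   Local GGUF (offline):     R1_REMOTE=0 GGUF_REPO=<repo_id> GGUF_FILE=<file.gguf> python app.py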
R1_REMOTE = os.getenv("R1_REMOTE", "0") == "1"
R1_MODEL_ID = os.getenv("R1_MODEL_ID", "Mouhib007/DeepSeek-r1-Medical-Mini")
HF_TOKEN = os.getenv("HF_TOKEN")  # public model -> may be None
# ---- Local GGUF (fallback / offline mode) ----
GGUF_CANDIDATES = []
ENV_REPO = os.getenv("GGUF_REPO", "").strip()
ENV_FILE = os.getenv("GGUF_FILE", "").strip()
if ENV_REPO and ENV_FILE:
    GGUF_CANDIDATES.append((ENV_REPO, ENV_FILE))
# Default candidate (adjust if you use a different one)
GGUF_CANDIDATES.append((
    "mradermacher/DeepSeek-r1-Medical-Mini-GGUF",
    "DeepSeek-r1-Medical-Mini.f16.gguf"
))
N_CTX = int(os.getenv("N_CTX", "2048"))
N_THREADS = int(os.getenv("N_THREADS", str(os.cpu_count() or 4)))
N_GPU_LAYERS = int(os.getenv("N_GPU_LAYERS", "0"))
N_BATCH = int(os.getenv("N_BATCH", "96"))
# ---- Remote client (HF Inference) ----
_remote_client = None
def get_remote_client():
    global _remote_client
    if _remote_client is None:
        _remote_client = InferenceClient(model=R1_MODEL_ID, token=HF_TOKEN, timeout=60)
    return _remote_client
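# Note: the client is created lazily and cached. Without HF_TOKEN the public
# Inference API still works for public models, but with stricter rate limits
# (hence the 401/429 fallback to the local GGUF below).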
# ---- ChatML formatting (compatible with DeepSeek/Qwen) ----
def _format_chatml(messages):
    parts = []
    for m in messages:
        role = m.get("role", "user")
        content = m.get("content", "")
        parts.append(f"<|im_start|>{role}\n{content}<|im_end|>\n")
    parts.append("<|im_start|>assistant\n")
    return "".join(parts)
def r1_chat(messages, temperature=0.2, max_tokens=384):
    """Remote (HF) or local (llama-cpp) inference for DeepSeek-R1 Medical Mini."""
    if R1_REMOTE:
        client = get_remote_client()
        try:
            # Some endpoints support chat_completion directly
            resp = client.chat_completion(messages=messages, temperature=temperature, max_tokens=max_tokens)
            return resp.choices[0].message.content
        except Exception:
            # Universal fallback: text_generation with a ChatML-formatted prompt
            try:
                prompt = _format_chatml(messages)
                return client.text_generation(
                    prompt,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    stop_sequences=["<|im_end|>"],
                    stream=False,
                )
            except Exception:
                # If remote fails (401/429/etc.), fall through to local GGUF if available
                pass
    # Local GGUF
    llm = get_llm()
    out = llm.create_chat_completion(messages=messages, temperature=temperature, max_tokens=max_tokens)
    return out["choices"][0]["message"]["content"]
# ---- Local loader (GGUF) ----
_llm = None
def _download_gguf():
    last_err = None
    for repo, fname in GGUF_CANDIDATES:
        try:
            return hf_hub_download(repo_id=repo, filename=fname), repo, fname
        except Exception as e:
            last_err = e
    raise RuntimeError(f"Could not download any GGUF. Last error: {last_err}")
def get_llm():
    global _llm
    if _llm is not None:
        return _llm
    gguf_path, _, _ = _download_gguf()
    _llm = Llama(
        model_path=gguf_path,
        # Do not force chat_format; use the one embedded in the R1 GGUF
        n_ctx=N_CTX,
        n_threads=N_THREADS,
        n_gpu_layers=N_GPU_LAYERS,
        n_batch=N_BATCH,
        verbose=False,
    )
    return _llm
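# Tuning notes (llama-cpp-python): N_GPU_LAYERS=0 keeps inference on the CPU;
# raising it offloads that many layers to the GPU when one is available.
# N_BATCH trades memory for prompt-processing throughput.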
# Optional warmup (avoids the wait on the first message when running locally)
if os.getenv("WARMUP", "0") == "1" and not R1_REMOTE:
    try:
        get_llm()
    except Exception:
        pass
# ===============================================================
# DeepSeek-OCR (UNCHANGED — with a fallback when FlashAttention2 is unavailable)
# ===============================================================
def _best_dtype():
    if torch.cuda.is_available():
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    return torch.float32
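# bf16 keeps fp32's exponent range (fp16 overflows above ~65k), so it is the
# safer half-precision choice; it requires Ampere-class GPUs or newer.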
def _load_ocr_model():
    model_name = "deepseek-ai/DeepSeek-OCR"
    ocr_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    attn_impl = os.getenv("OCR_ATTN_IMPL", "flash_attention_2")  # same default as before
    try:
        ocr_model = AutoModel.from_pretrained(
            model_name,
            _attn_implementation=attn_impl,
            trust_remote_code=True,
            use_safetensors=True,
        ).eval()
        return ocr_tokenizer, ocr_model
    except Exception as e:
        # If it fails due to FlashAttention2, retry in "eager" mode (CPU/compat)
        msg = str(e)
        if "flash_attn" in msg or "FlashAttention2" in msg or "flash_attention_2" in msg:
            ocr_model = AutoModel.from_pretrained(
                model_name,
                _attn_implementation="eager",
                trust_remote_code=True,
                use_safetensors=True,
            ).eval()
            return ocr_tokenizer, ocr_model
        raise
tokenizer, model = _load_ocr_model()
@spaces.GPU
def process_image(image, model_size, task_type, is_eval_mode):
    """
    Returns: annotated image, markdown, and plain text (or the markdown when no text is produced).
    """
    if image is None:
        return None, "Please upload an image first.", "Please upload an image first."
    dtype = _best_dtype()
    model_device = model.cuda().to(dtype) if torch.cuda.is_available() else model.to(dtype)
    with tempfile.TemporaryDirectory() as output_path:
        if task_type == "Free OCR":
            prompt = "<image>\nFree OCR. "
        elif task_type == "Convert to Markdown":
            prompt = "<image>\n<|grounding|>Convert the document to markdown. "
        else:
            prompt = "<image>\nFree OCR. "
        temp_image_path = os.path.join(output_path, "temp_image.jpg")
        image.save(temp_image_path)
        size_configs = {
            "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
            "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
            "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
            "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
            "Gundam (Recommended)": {"base_size": 1024, "image_size": 640, "crop_mode": True},
        }
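        # "Gundam (Recommended)" follows DeepSeek-OCR's dynamic-resolution preset:
        # a 1024-px global view plus 640-px crops (crop_mode=True). The fixed
        # presets render the page on a single square canvas at the given size.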
        config = size_configs.get(model_size, size_configs["Gundam (Recommended)"])
        plain_text_result = model_device.infer(
            tokenizer,
            prompt=prompt,
            image_file=temp_image_path,
            output_path=output_path,
            base_size=config["base_size"],
            image_size=config["image_size"],
            crop_mode=config["crop_mode"],
            save_results=True,
            test_compress=True,
            eval_mode=is_eval_mode,
        )
        image_result_path = os.path.join(output_path, "result_with_boxes.jpg")
        markdown_result_path = os.path.join(output_path, "result.mmd")
        if os.path.exists(markdown_result_path):
            with open(markdown_result_path, "r", encoding="utf-8") as f:
                markdown_content = f.read()
        else:
            markdown_content = "Markdown result was not generated. This is expected for the 'Free OCR' task."
        result_image = None
        if os.path.exists(image_result_path):
            result_image = Image.open(image_result_path)
            result_image.load()  # read pixels now, before the temp directory is deleted
        text_result = plain_text_result if plain_text_result else markdown_content
        return result_image, markdown_content, text_result
# ===============================================================
# Chat (injects the OCR into the first system message) — using R1
# ===============================================================
def _truncate(text, max_chars=3000):
    return (text or "")[:max_chars]
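# Rough budget (assumption): ~3000 chars is on the order of 750-1000 tokens,
# which leaves room for the reply inside the default local N_CTX of 2048.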
def _system_prompt():
    return (
        "You are an educational clinical assistant. You do not replace medical judgment. "
        "Use CONTEXTO_OCR when it exists; if it is missing, ask for it. Avoid definitive diagnoses."
    )
def _ocr_context(ocr_md, ocr_txt):
    # Prefer the markdown result; fall back to the plain-text result
    return _truncate(ocr_md) or _truncate(ocr_txt) or ""
def to_chat_messages(chat_msgs, ocr_md, ocr_txt):
    sys = _system_prompt()
    ctx = _ocr_context(ocr_md, ocr_txt)
    if ctx:
        sys += (
            "\n\n---\n"
            "CONTEXTO_OCR (primary source; if a piece of data is missing, say so explicitly):\n"
            f"{ctx}\n---"
        )
    msgs = [{"role": "system", "content": sys}]
    for m in (chat_msgs or []):
        if m.get("role") in ("user", "assistant"):
            msgs.append({"role": m["role"], "content": m.get("content", "")})
    return msgs
def r1_reply(user_msg, chat_msgs, ocr_md, ocr_txt):
    if not user_msg:
        user_msg = "Analyze the CONTEXTO_OCR above and answer based on that content."
    try:
        msgs = to_chat_messages(chat_msgs, ocr_md, ocr_txt) + [{"role": "user", "content": user_msg}]
        answer = r1_chat(msgs, temperature=0.2, max_tokens=512)
        updated = (chat_msgs or []) + [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": answer},
        ]
        return updated, "", gr.update(value="")
    except Exception as e:
        err = f"{e.__class__.__name__}: {str(e) or repr(e)}"
        tb = traceback.format_exc(limit=2)
        updated = (chat_msgs or []) + [
            {"role": "user", "content": user_msg or ""},
            {"role": "assistant", "content": f"⚠️ LLM error: {err}"},
        ]
        return updated, "", gr.update(value=f"{err}\n{tb}")
def clear_chat():
    return [], "", gr.update(value="")
# ===============================================================
# UI (Gradio 5)
# ===============================================================
with gr.Blocks(title="DeepSeek-OCR + DeepSeek-R1 Medical Mini", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # DeepSeek-OCR → Medical Chat with **DeepSeek-R1 Medical Mini** (remote HF or local GGUF)
        1) **Upload an image** and run **OCR** (annotated image, Markdown, and text).
        2) **Chat** with **DeepSeek-R1 Medical Mini**, with the **OCR** automatically used as context.
        *Educational use; not a replacement for medical advice.*
        """
    )
    ocr_md_state = gr.State("")
    ocr_txt_state = gr.State("")
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image", sources=["upload", "clipboard", "webcam"])
            model_size = gr.Dropdown(
                choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
                value="Gundam (Recommended)", label="Model Size",
            )
            task_type = gr.Dropdown(
                choices=["Free OCR", "Convert to Markdown"],
                value="Convert to Markdown", label="Task Type",
            )
            eval_mode_checkbox = gr.Checkbox(
                value=False, label="Enable Evaluation Mode",
                info="Text only (faster). Uncheck it to see the annotated image and markdown.",
            )
            submit_btn = gr.Button("Process Image", variant="primary")
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Annotated Image"):
                    output_image = gr.Image(interactive=False)
                with gr.TabItem("Markdown Preview"):
                    output_markdown = gr.Markdown()
                with gr.TabItem("Markdown Source (or Eval Output)"):
                    output_text = gr.Textbox(lines=18, show_copy_button=True, interactive=False)
    with gr.Row():
        md_preview = gr.Textbox(label="OCR Markdown Snapshot", lines=10, interactive=False)
        txt_preview = gr.Textbox(label="OCR Text Snapshot", lines=10, interactive=False)
    gr.Markdown("## Clinical Chat (DeepSeek-R1 Medical Mini)")
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="OCR Assistant (R1 Medical Mini)", type="messages", height=420)
            user_in = gr.Textbox(label="Message", placeholder="Type your query… (empty = analyze the OCR only)", lines=2)
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear")
        with gr.Column(scale=1):
            error_box = gr.Textbox(label="Debug (shown on error)", lines=8, interactive=False)
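    # Note: the chained .then() below runs after process_image finishes and
    # copies the fresh OCR into gr.State (plus the snapshot boxes) so that
    # r1_reply can inject it into the system prompt on every chat turn.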
    # OCR → outputs and session state
    submit_btn.click(
        fn=process_image,
        inputs=[image_input, model_size, task_type, eval_mode_checkbox],
        outputs=[output_image, output_markdown, output_text],
    ).then(
        fn=lambda md, tx: (md, tx, md, tx),
        inputs=[output_markdown, output_text],
        outputs=[ocr_md_state, ocr_txt_state, md_preview, txt_preview],
    )
    # Chat
    send_btn.click(
        fn=r1_reply,
        inputs=[user_in, chatbot, ocr_md_state, ocr_txt_state],
        outputs=[chatbot, user_in, error_box],
    )
    clear_btn.click(fn=clear_chat, outputs=[chatbot, user_in, error_box])
if __name__ == "__main__":
    demo.queue(max_size=20)
    demo.launch()