# OpScanIA / app.py
# ------------------------------------------------------------------------------------------------
import os, tempfile, traceback
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
import spaces
from huggingface_hub import InferenceClient
# =========================
# Remote chat configuration
# =========================
TX_MODEL_ID = os.getenv("TX_MODEL_ID", "mims-harvard/TxAgent-T1-Llama-3.1-8B")
HF_TOKEN = os.getenv("HF_TOKEN")
GEN_MAX_NEW_TOKENS = int(os.getenv("GEN_MAX_NEW_TOKENS", "512"))
GEN_TEMPERATURE = float(os.getenv("GEN_TEMPERATURE", "0.2"))
GEN_TOP_P = float(os.getenv("GEN_TOP_P", "0.9"))
# Remote client for the model.
# Key detail: provider="featherless-ai" is the provider that does support this model in conversational mode.
tx_client = InferenceClient(
model=TX_MODEL_ID,
provider="featherless-ai",
token=HF_TOKEN,
timeout=60.0,
)
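# A minimal smoke-test sketch (assumes HF_TOKEN grants access to this provider;
# run it manually from a REPL, not at import time):
#
#   resp = tx_client.chat.completions.create(
#       model=TX_MODEL_ID,
#       messages=[{"role": "user", "content": "ping"}],
#       max_tokens=8,
#   )
#   print(resp.choices[0].message.content)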
def _system_prompt():
return (
"Eres un asistente clínico educativo. NO sustituyes el juicio médico.\n"
"Usa CONTEXTO_OCR si existe; si falta, dilo explícitamente. "
"No inventes datos que no estén en el OCR ni hagas diagnósticos definitivos."
)
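# (English gloss of the Spanish system prompt above: "You are an educational
# clinical assistant. You do NOT replace medical judgment. Use CONTEXTO_OCR if
# it exists; if it is missing, say so explicitly. Do not invent data that is
# not in the OCR and do not make definitive diagnoses.")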
def _mk_messages_for_provider(ocr_md: str, ocr_txt: str, user_msg: str):
"""
    This format is exactly what chat.completions.create() expects:
    a list of dicts with role: system/user/assistant.
"""
ctx = (ocr_md or "")[:3000] or (ocr_txt or "")[:3000]
sys_content = _system_prompt()
if ctx:
sys_content += (
"\n\n---\n"
"CONTEXTO_OCR (extraído de la imagen):\n"
f"{ctx}\n"
"---\n"
"Responde basándote en ese contenido. Si falta información, dilo."
)
if not user_msg:
user_msg = (
"Analiza el CONTEXTO_OCR anterior y explícame, en lenguaje claro, "
"qué medicamentos aparecen, dosis y advertencias importantes."
)
return [
{"role": "system", "content": sys_content},
{"role": "user", "content": user_msg},
]
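# Illustrative shape of the returned list (contents abridged):
#
#   [
#       {"role": "system", "content": "Eres un asistente clínico... CONTEXTO_OCR: ..."},
#       {"role": "user", "content": "Analiza el CONTEXTO_OCR anterior..."},
#   ]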
def txagent_chat_remote(ocr_md: str, ocr_txt: str, user_msg: str) -> str:
"""
    Calls the featherless-ai provider's 'conversational' task for TxAgent.
    This avoids the errors:
      - text-generation not supported
      - 404 from hf-inference
"""
messages = _mk_messages_for_provider(ocr_md, ocr_txt, user_msg)
try:
completion = tx_client.chat.completions.create(
model=TX_MODEL_ID,
messages=messages,
max_tokens=GEN_MAX_NEW_TOKENS,
temperature=GEN_TEMPERATURE,
top_p=GEN_TOP_P,
stream=False,
)
        # The completion object exposes .choices[i].message.content
        return completion.choices[0].message.content
    except Exception as e:
        raise RuntimeError(f"Inference error: {e.__class__.__name__}: {e}") from e
# =========================
# Local OCR: DeepSeek-OCR
# =========================
def _best_dtype():
if torch.cuda.is_available():
return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
return torch.float32
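# For example: an A100 or H100 yields torch.bfloat16, a T4 (no bf16 support)
# falls back to torch.float16, and CPU-only runs use torch.float32.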
def _load_ocr_model():
"""
    Load DeepSeek-OCR with trust_remote_code.
    IMPORTANT: we do not move it to CUDA here. That happens only in the @spaces.GPU worker.
"""
model_id = "deepseek-ai/DeepSeek-OCR"
    revision = os.getenv("OCR_REVISION", None)  # pin a commit so upstream repo changes can't break us
attn_impl = os.getenv("OCR_ATTN_IMPL", "flash_attention_2")
ocr_tokenizer = AutoTokenizer.from_pretrained(
model_id,
trust_remote_code=True,
revision=revision,
)
try:
ocr_model = AutoModel.from_pretrained(
model_id,
trust_remote_code=True,
use_safetensors=True,
_attn_implementation=attn_impl,
revision=revision,
).eval()
return ocr_tokenizer, ocr_model
except Exception as e:
        # fallback without FlashAttention2
if any(k in str(e).lower() for k in ["flash_attn", "flashattention2", "flash_attention_2"]):
ocr_model = AutoModel.from_pretrained(
model_id,
trust_remote_code=True,
use_safetensors=True,
_attn_implementation="eager",
revision=revision,
).eval()
return ocr_tokenizer, ocr_model
raise
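# Hedged usage note: to pin the remote code to a fixed commit and skip
# FlashAttention2 up front, set the env vars before launching
# (<commit-sha> is a placeholder, not a real revision):
#
#   OCR_REVISION=<commit-sha> OCR_ATTN_IMPL=eager python app.py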
OCR_TOKENIZER, OCR_MODEL = _load_ocr_model()
@spaces.GPU  # <- the ONLY place where we touch CUDA. Complies with the Spaces Zero policy.
def ocr_infer(image: Image.Image, model_size: str, task_type: str, is_eval_mode: bool):
"""
    Runs OCR on the GPU (if available) and returns:
      - the annotated image (may be None in eval_mode)
      - the OCR markdown
      - the OCR plain text
"""
if image is None:
return None, "Sube una imagen primero.", "Sube una imagen primero."
dtype = _best_dtype()
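    # Note: .cuda()/.to() move the module in place, so this also mutates the
    # global OCR_MODEL; that is acceptable here because this code only runs
    # inside the @spaces.GPU worker.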
model_local = OCR_MODEL.cuda().to(dtype) if torch.cuda.is_available() else OCR_MODEL.to(dtype)
with tempfile.TemporaryDirectory() as outdir:
        # choose the prompt for the selected task
if task_type == "Free OCR":
prompt = "<image>\nFree OCR. "
else:
prompt = "<image>\n<|grounding|>Convert the document to markdown. "
size_cfgs = {
"Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
"Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
"Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
"Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
"Gundam (Recommended)": {
"base_size": 1024,
"image_size": 640,
"crop_mode": True
},
}
cfg = size_cfgs.get(model_size, size_cfgs["Gundam (Recommended)"])
tmp_path = os.path.join(outdir, "tmp.jpg")
image.save(tmp_path)
plain_text_result = model_local.infer(
OCR_TOKENIZER,
prompt=prompt,
image_file=tmp_path,
output_path=outdir,
base_size=cfg["base_size"],
image_size=cfg["image_size"],
crop_mode=cfg["crop_mode"],
save_results=True,
test_compress=True,
eval_mode=is_eval_mode,
)
img_boxes_path = os.path.join(outdir, "result_with_boxes.jpg")
md_path = os.path.join(outdir, "result.mmd")
markdown_content = (
"Markdown result was not generated. This is expected for 'Free OCR' task."
)
if os.path.exists(md_path):
with open(md_path, "r", encoding="utf-8") as f:
markdown_content = f.read()
annotated_img = None
if os.path.exists(img_boxes_path):
annotated_img = Image.open(img_boxes_path)
annotated_img.load()
text_out = plain_text_result if plain_text_result else markdown_content
return annotated_img, markdown_content, text_out
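# Minimal sketch of exercising the OCR path outside the UI (assumes a local
# "sample.jpg" exists; illustrative only):
#
#   img = Image.open("sample.jpg").convert("RGB")
#   annotated, md, txt = ocr_infer(img, "Gundam (Recommended)", "Convert to Markdown", True)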
# =========================
# UI state / helpers
# =========================
def ocr_snapshot(md_text: str, plain_text: str):
"""
    Store the OCR output in state (so it can be sent to the chat later)
    and also return it for the quick-preview boxes.
"""
return md_text, plain_text, md_text, plain_text
def chat_reply(user_msg, chat_state, ocr_md_state, ocr_txt_state):
"""
    Logic behind the chat's "Enviar" (Send) button.
"""
try:
answer = txagent_chat_remote(
ocr_md_state or "",
ocr_txt_state or "",
user_msg or ""
)
updated = (chat_state or []) + [
{"role": "user", "content": user_msg or "(solo OCR)"},
{"role": "assistant", "content": answer},
]
return updated, "", ""
except Exception as e:
tb = traceback.format_exc(limit=2)
updated = (chat_state or []) + [
{"role": "user", "content": user_msg or ""},
{
"role": "assistant",
"content": f"⚠️ Error remoto (chat): {e.__class__.__name__}: {e}",
},
]
return updated, "", f"{e}\n{tb}"
def clear_chat():
return [], "", ""
# =========================
# Gradio 5 UI
# =========================
with gr.Blocks(
title="OpScanIA",
theme=gr.themes.Soft()
) as demo:
gr.Markdown(
"""
# 📄 DeepSeek-OCR → 💬 Chat Clínico
1. **Sube una imagen** y corre **OCR** (imagen anotada, Markdown y texto).
2. El chat usa automáticamente el texto detectado por OCR
como contexto clínico.
⚠ Uso educativo. No reemplaza consejo médico profesional.
"""
)
    # State to pass OCR -> chat
ocr_md_state = gr.State("")
ocr_txt_state = gr.State("")
with gr.Row():
        # OCR panel
with gr.Column(scale=1):
image_input = gr.Image(
type="pil",
label="Upload Image",
sources=["upload", "clipboard", "webcam"]
)
model_size = gr.Dropdown(
choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
value="Gundam (Recommended)",
label="OCR Model Size"
)
task_type = gr.Dropdown(
choices=["Free OCR", "Convert to Markdown"],
value="Convert to Markdown",
label="OCR Task"
)
eval_mode_checkbox = gr.Checkbox(
value=True,
label="Evaluation mode (más rápido)",
info="Puede omitir imagen anotada y concentrarse en el texto."
)
submit_btn = gr.Button("Process Image", variant="primary")
        # OCR results
with gr.Column(scale=2):
with gr.Tabs():
with gr.TabItem("Annotated Image"):
output_image = gr.Image(interactive=False)
with gr.TabItem("Markdown Preview"):
output_markdown = gr.Markdown()
with gr.TabItem("Markdown / OCR Text"):
output_text = gr.Textbox(
lines=18,
show_copy_button=True,
interactive=False
)
with gr.Row():
md_preview = gr.Textbox(
label="Snapshot Markdown OCR",
lines=8,
interactive=False
)
txt_preview = gr.Textbox(
label="Snapshot Texto OCR",
lines=8,
interactive=False
)
    # Chat panel
gr.Markdown("## Chat Clínico")
with gr.Row():
with gr.Column(scale=2):
chatbot = gr.Chatbot(
label="Asistente OCR (TxAgent remoto)",
type="messages",
height=420,
)
user_in = gr.Textbox(
label="Mensaje",
placeholder="Escribe tu consulta… (vacío = analiza solo el OCR)",
lines=2,
)
with gr.Row():
send_btn = gr.Button("Enviar", variant="primary")
clear_btn = gr.Button("Limpiar")
with gr.Column(scale=1):
error_box = gr.Textbox(
label="Debug (si hay error)",
lines=8,
interactive=False
)
# Wiring OCR
submit_btn.click(
fn=ocr_infer,
inputs=[image_input, model_size, task_type, eval_mode_checkbox],
outputs=[output_image, output_markdown, output_text],
).then(
fn=ocr_snapshot,
inputs=[output_markdown, output_text],
outputs=[ocr_md_state, ocr_txt_state, md_preview, txt_preview],
)
# Wiring Chat
send_btn.click(
fn=chat_reply,
inputs=[user_in, chatbot, ocr_md_state, ocr_txt_state],
outputs=[chatbot, user_in, error_box],
)
clear_btn.click(
fn=clear_chat,
outputs=[chatbot, user_in, error_box],
)
if __name__ == "__main__":
    # Gradio 5: queue() no longer takes concurrency_count
    # demo.queue(max_size=32)  # optional, if you want to cap the queue
demo.launch()
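    # Hedged sketch of optional tuning (illustrative values; queue() and this
    # launch() argument exist in Gradio 5):
    #
    #   demo.queue(max_size=32).launch(show_error=True)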