# ------------------------------------------------------------------------------------------------ import os, tempfile, traceback import gradio as gr import torch from PIL import Image from transformers import AutoModel, AutoTokenizer import spaces from huggingface_hub import InferenceClient # ========================= # Configuración del Chat remoto # ========================= TX_MODEL_ID = os.getenv("TX_MODEL_ID", "mims-harvard/TxAgent-T1-Llama-3.1-8B") HF_TOKEN = os.getenv("HF_TOKEN") GEN_MAX_NEW_TOKENS = int(os.getenv("GEN_MAX_NEW_TOKENS", "512")) GEN_TEMPERATURE = float(os.getenv("GEN_TEMPERATURE", "0.2")) GEN_TOP_P = float(os.getenv("GEN_TOP_P", "0.9")) # Cliente remoto del modelo. # Clave: provider="featherless-ai", que es el que sí soporta este modelo en modo conversational. tx_client = InferenceClient( model=TX_MODEL_ID, provider="featherless-ai", token=HF_TOKEN, timeout=60.0, ) def _system_prompt(): return ( "Eres un asistente clínico educativo. NO sustituyes el juicio médico.\n" "Usa CONTEXTO_OCR si existe; si falta, dilo explícitamente. " "No inventes datos que no estén en el OCR ni hagas diagnósticos definitivos." ) def _mk_messages_for_provider(ocr_md: str, ocr_txt: str, user_msg: str): """ Este formato es exactamente el que espera chat.completions.create(): lista de dicts con role: system/user/assistant. """ ctx = (ocr_md or "")[:3000] or (ocr_txt or "")[:3000] sys_content = _system_prompt() if ctx: sys_content += ( "\n\n---\n" "CONTEXTO_OCR (extraído de la imagen):\n" f"{ctx}\n" "---\n" "Responde basándote en ese contenido. Si falta información, dilo." ) if not user_msg: user_msg = ( "Analiza el CONTEXTO_OCR anterior y explícame, en lenguaje claro, " "qué medicamentos aparecen, dosis y advertencias importantes." ) return [ {"role": "system", "content": sys_content}, {"role": "user", "content": user_msg}, ] def txagent_chat_remote(ocr_md: str, ocr_txt: str, user_msg: str) -> str: """ Llama a la tarea 'conversational' del provider featherless-ai para TxAgent. Esto evita el error: - text-generation no soportado - 404 de hf-inference """ messages = _mk_messages_for_provider(ocr_md, ocr_txt, user_msg) try: completion = tx_client.chat.completions.create( model=TX_MODEL_ID, messages=messages, max_tokens=GEN_MAX_NEW_TOKENS, temperature=GEN_TEMPERATURE, top_p=GEN_TOP_P, stream=False, ) # El objeto completion tiene .choices[i].message.content return completion.choices[0].message.content except Exception as e: raise RuntimeError(f"Inference error: {e.__class__.__name__}: {e}") # ========================= # OCR local — DeepSeek-OCR # ========================= def _best_dtype(): if torch.cuda.is_available(): return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 return torch.float32 def _load_ocr_model(): """ Cargamos DeepSeek-OCR con trust_remote_code. IMPORTANTE: no movemos a CUDA aquí. Eso solo ocurre en el worker @spaces.GPU. """ model_id = "deepseek-ai/DeepSeek-OCR" revision = os.getenv("OCR_REVISION", None) # pin commit para evitar que cambie el repo attn_impl = os.getenv("OCR_ATTN_IMPL", "flash_attention_2") ocr_tokenizer = AutoTokenizer.from_pretrained( model_id, trust_remote_code=True, revision=revision, ) try: ocr_model = AutoModel.from_pretrained( model_id, trust_remote_code=True, use_safetensors=True, _attn_implementation=attn_impl, revision=revision, ).eval() return ocr_tokenizer, ocr_model except Exception as e: # fallback sin FlashAttention2 if any(k in str(e).lower() for k in ["flash_attn", "flashattention2", "flash_attention_2"]): ocr_model = AutoModel.from_pretrained( model_id, trust_remote_code=True, use_safetensors=True, _attn_implementation="eager", revision=revision, ).eval() return ocr_tokenizer, ocr_model raise OCR_TOKENIZER, OCR_MODEL = _load_ocr_model() @spaces.GPU # <- ÚNICO sitio donde tocamos CUDA. Cumple con la política de Spaces Zero. def ocr_infer(image: Image.Image, model_size: str, task_type: str, is_eval_mode: bool): """ Ejecuta OCR en GPU (si hay) y devuelve: - imagen anotada (puede ser None en eval_mode) - markdown OCR - texto llano OCR """ if image is None: return None, "Sube una imagen primero.", "Sube una imagen primero." dtype = _best_dtype() model_local = OCR_MODEL.cuda().to(dtype) if torch.cuda.is_available() else OCR_MODEL.to(dtype) with tempfile.TemporaryDirectory() as outdir: # prompt según modo if task_type == "Free OCR": prompt = "\nFree OCR. " else: prompt = "\n<|grounding|>Convert the document to markdown. " size_cfgs = { "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False}, "Small": {"base_size": 640, "image_size": 640, "crop_mode": False}, "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False}, "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False}, "Gundam (Recommended)": { "base_size": 1024, "image_size": 640, "crop_mode": True }, } cfg = size_cfgs.get(model_size, size_cfgs["Gundam (Recommended)"]) tmp_path = os.path.join(outdir, "tmp.jpg") image.save(tmp_path) plain_text_result = model_local.infer( OCR_TOKENIZER, prompt=prompt, image_file=tmp_path, output_path=outdir, base_size=cfg["base_size"], image_size=cfg["image_size"], crop_mode=cfg["crop_mode"], save_results=True, test_compress=True, eval_mode=is_eval_mode, ) img_boxes_path = os.path.join(outdir, "result_with_boxes.jpg") md_path = os.path.join(outdir, "result.mmd") markdown_content = ( "Markdown result was not generated. This is expected for 'Free OCR' task." ) if os.path.exists(md_path): with open(md_path, "r", encoding="utf-8") as f: markdown_content = f.read() annotated_img = None if os.path.exists(img_boxes_path): annotated_img = Image.open(img_boxes_path) annotated_img.load() text_out = plain_text_result if plain_text_result else markdown_content return annotated_img, markdown_content, text_out # ========================= # Estados / helpers para la UI # ========================= def ocr_snapshot(md_text: str, plain_text: str): """ Guardamos el OCR en estados (para enviarlo al chat después) y devolvemos esas vistas rápidas. """ return md_text, plain_text, md_text, plain_text def chat_reply(user_msg, chat_state, ocr_md_state, ocr_txt_state): """ Lógica del botón "Enviar" en el chat. """ try: answer = txagent_chat_remote( ocr_md_state or "", ocr_txt_state or "", user_msg or "" ) updated = (chat_state or []) + [ {"role": "user", "content": user_msg or "(solo OCR)"}, {"role": "assistant", "content": answer}, ] return updated, "", "" except Exception as e: tb = traceback.format_exc(limit=2) updated = (chat_state or []) + [ {"role": "user", "content": user_msg or ""}, { "role": "assistant", "content": f"⚠️ Error remoto (chat): {e.__class__.__name__}: {e}", }, ] return updated, "", f"{e}\n{tb}" def clear_chat(): return [], "", "" # ========================= # UI en Gradio 5 # ========================= with gr.Blocks( title="OpScanIA", theme=gr.themes.Soft() ) as demo: gr.Markdown( """ # 📄 DeepSeek-OCR → 💬 Chat Clínico 1. **Sube una imagen** y corre **OCR** (imagen anotada, Markdown y texto). 2. El chat usa automáticamente el texto detectado por OCR como contexto clínico. ⚠ Uso educativo. No reemplaza consejo médico profesional. """ ) # Estados para pasar OCR -> Chat ocr_md_state = gr.State("") ocr_txt_state = gr.State("") with gr.Row(): # Panel OCR with gr.Column(scale=1): image_input = gr.Image( type="pil", label="Upload Image", sources=["upload", "clipboard", "webcam"] ) model_size = gr.Dropdown( choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"], value="Gundam (Recommended)", label="OCR Model Size" ) task_type = gr.Dropdown( choices=["Free OCR", "Convert to Markdown"], value="Convert to Markdown", label="OCR Task" ) eval_mode_checkbox = gr.Checkbox( value=True, label="Evaluation mode (más rápido)", info="Puede omitir imagen anotada y concentrarse en el texto." ) submit_btn = gr.Button("Process Image", variant="primary") # Resultados OCR with gr.Column(scale=2): with gr.Tabs(): with gr.TabItem("Annotated Image"): output_image = gr.Image(interactive=False) with gr.TabItem("Markdown Preview"): output_markdown = gr.Markdown() with gr.TabItem("Markdown / OCR Text"): output_text = gr.Textbox( lines=18, show_copy_button=True, interactive=False ) with gr.Row(): md_preview = gr.Textbox( label="Snapshot Markdown OCR", lines=8, interactive=False ) txt_preview = gr.Textbox( label="Snapshot Texto OCR", lines=8, interactive=False ) # Panel Chat gr.Markdown("## Chat Clínico") with gr.Row(): with gr.Column(scale=2): chatbot = gr.Chatbot( label="Asistente OCR (TxAgent remoto)", type="messages", height=420, ) user_in = gr.Textbox( label="Mensaje", placeholder="Escribe tu consulta… (vacío = analiza solo el OCR)", lines=2, ) with gr.Row(): send_btn = gr.Button("Enviar", variant="primary") clear_btn = gr.Button("Limpiar") with gr.Column(scale=1): error_box = gr.Textbox( label="Debug (si hay error)", lines=8, interactive=False ) # Wiring OCR submit_btn.click( fn=ocr_infer, inputs=[image_input, model_size, task_type, eval_mode_checkbox], outputs=[output_image, output_markdown, output_text], ).then( fn=ocr_snapshot, inputs=[output_markdown, output_text], outputs=[ocr_md_state, ocr_txt_state, md_preview, txt_preview], ) # Wiring Chat send_btn.click( fn=chat_reply, inputs=[user_in, chatbot, ocr_md_state, ocr_txt_state], outputs=[chatbot, user_in, error_box], ) clear_btn.click( fn=clear_chat, outputs=[chatbot, user_in, error_box], ) if __name__ == "__main__": # Gradio 5: sin concurrency_count en queue() # demo.queue(max_size=32) # opcional si quieres limitar cola demo.launch()