# ------------------------------------------------------------------------------------------------
import os, tempfile, traceback

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
import spaces
from huggingface_hub import InferenceClient

# =========================
# Remote chat configuration
# =========================
TX_MODEL_ID = os.getenv("TX_MODEL_ID", "mims-harvard/TxAgent-T1-Llama-3.1-8B")
HF_TOKEN = os.getenv("HF_TOKEN")
GEN_MAX_NEW_TOKENS = int(os.getenv("GEN_MAX_NEW_TOKENS", "512"))
GEN_TEMPERATURE = float(os.getenv("GEN_TEMPERATURE", "0.2"))
GEN_TOP_P = float(os.getenv("GEN_TOP_P", "0.9"))

# Remote client for the model.
# Key detail: provider="featherless-ai" is the provider that actually serves
# this model in conversational mode.
tx_client = InferenceClient(
    model=TX_MODEL_ID,
    provider="featherless-ai",
    token=HF_TOKEN,
    timeout=60.0,
)
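
# Optional smoke test (a sketch, not part of the app flow): uncomment to verify
# that the featherless-ai route and HF_TOKEN work before launching the UI.
# _ping = tx_client.chat.completions.create(
#     model=TX_MODEL_ID,
#     messages=[{"role": "user", "content": "ping"}],
#     max_tokens=8,
# )
# print(_ping.choices[0].message.content)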
def _system_prompt():
    return (
        "You are an educational clinical assistant. You do NOT replace medical judgment.\n"
        "Use CONTEXTO_OCR if present; if it is missing, say so explicitly. "
        "Do not invent data that is not in the OCR and do not make definitive diagnoses."
    )


def _mk_messages_for_provider(ocr_md: str, ocr_txt: str, user_msg: str):
    """
    This is exactly the format chat.completions.create() expects:
    a list of dicts with role: system/user/assistant.
    """
    ctx = (ocr_md or "")[:3000] or (ocr_txt or "")[:3000]
    sys_content = _system_prompt()
    if ctx:
        sys_content += (
            "\n\n---\n"
            "CONTEXTO_OCR (extracted from the image):\n"
            f"{ctx}\n"
            "---\n"
            "Base your answer on that content. If information is missing, say so."
        )
    if not user_msg:
        user_msg = (
            "Analyze the CONTEXTO_OCR above and explain, in plain language, "
            "which medications appear, their doses, and any important warnings."
        )
    return [
        {"role": "system", "content": sys_content},
        {"role": "user", "content": user_msg},
    ]
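
# Illustrative shape of the returned list (contents abridged; actual values
# depend on the OCR context and the user's message):
# [
#     {"role": "system", "content": "You are an educational clinical assistant... CONTEXTO_OCR ..."},
#     {"role": "user", "content": "Analyze the CONTEXTO_OCR above..."},
# ]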
def txagent_chat_remote(ocr_md: str, ocr_txt: str, user_msg: str) -> str:
    """
    Calls the 'conversational' task of the featherless-ai provider for TxAgent.
    This avoids the errors:
      - text-generation not supported
      - 404 from hf-inference
    """
    messages = _mk_messages_for_provider(ocr_md, ocr_txt, user_msg)
    try:
        completion = tx_client.chat.completions.create(
            model=TX_MODEL_ID,
            messages=messages,
            max_tokens=GEN_MAX_NEW_TOKENS,
            temperature=GEN_TEMPERATURE,
            top_p=GEN_TOP_P,
            stream=False,
        )
        # The completion object exposes .choices[i].message.content
        return completion.choices[0].message.content
    except Exception as e:
        raise RuntimeError(f"Inference error: {e.__class__.__name__}: {e}") from e
# =========================
# Local OCR — DeepSeek-OCR
# =========================
def _best_dtype():
    if torch.cuda.is_available():
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    return torch.float32


def _load_ocr_model():
    """
    Load DeepSeek-OCR with trust_remote_code.
    IMPORTANT: we do NOT move it to CUDA here. That only happens inside the
    @spaces.GPU worker.
    """
    model_id = "deepseek-ai/DeepSeek-OCR"
    revision = os.getenv("OCR_REVISION", None)  # pin a commit so upstream repo changes can't break us
    attn_impl = os.getenv("OCR_ATTN_IMPL", "flash_attention_2")
    ocr_tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
        revision=revision,
    )
    try:
        ocr_model = AutoModel.from_pretrained(
            model_id,
            trust_remote_code=True,
            use_safetensors=True,
            _attn_implementation=attn_impl,
            revision=revision,
        ).eval()
        return ocr_tokenizer, ocr_model
    except Exception as e:
        # Fallback without FlashAttention2
        if any(k in str(e).lower() for k in ["flash_attn", "flashattention2", "flash_attention_2"]):
            ocr_model = AutoModel.from_pretrained(
                model_id,
                trust_remote_code=True,
                use_safetensors=True,
                _attn_implementation="eager",
                revision=revision,
            ).eval()
            return ocr_tokenizer, ocr_model
        raise


OCR_TOKENIZER, OCR_MODEL = _load_ocr_model()


# The worker below is the ONLY place where we touch CUDA. This complies with
# the Spaces Zero policy.
@spaces.GPU
def ocr_infer(image: Image.Image, model_size: str, task_type: str, is_eval_mode: bool):
    """
    Runs OCR on GPU (if available) and returns:
      - annotated image (may be None in eval_mode)
      - OCR markdown
      - OCR plain text
    """
    if image is None:
        return None, "Upload an image first.", "Upload an image first."
    dtype = _best_dtype()
    model_local = OCR_MODEL.cuda().to(dtype) if torch.cuda.is_available() else OCR_MODEL.to(dtype)
    with tempfile.TemporaryDirectory() as outdir:
        # Prompt depends on the task
        if task_type == "Free OCR":
            prompt = "<image>\nFree OCR. "
        else:
            prompt = "<image>\n<|grounding|>Convert the document to markdown. "
        size_cfgs = {
            "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
            "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
            "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
            "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
            "Gundam (Recommended)": {
                "base_size": 1024,
                "image_size": 640,
                "crop_mode": True,
            },
        }
        cfg = size_cfgs.get(model_size, size_cfgs["Gundam (Recommended)"])
        tmp_path = os.path.join(outdir, "tmp.jpg")
        # Convert to RGB so RGBA inputs (e.g. from the clipboard) can be saved as JPEG
        image.convert("RGB").save(tmp_path)
        plain_text_result = model_local.infer(
            OCR_TOKENIZER,
            prompt=prompt,
            image_file=tmp_path,
            output_path=outdir,
            base_size=cfg["base_size"],
            image_size=cfg["image_size"],
            crop_mode=cfg["crop_mode"],
            save_results=True,
            test_compress=True,
            eval_mode=is_eval_mode,
        )
        img_boxes_path = os.path.join(outdir, "result_with_boxes.jpg")
        md_path = os.path.join(outdir, "result.mmd")
        markdown_content = (
            "Markdown result was not generated. This is expected for the 'Free OCR' task."
        )
        if os.path.exists(md_path):
            with open(md_path, "r", encoding="utf-8") as f:
                markdown_content = f.read()
        annotated_img = None
        if os.path.exists(img_boxes_path):
            annotated_img = Image.open(img_boxes_path)
            annotated_img.load()  # force-read before the temp dir is deleted
        text_out = plain_text_result if plain_text_result else markdown_content
        return annotated_img, markdown_content, text_out
# =========================
# States / helpers for the UI
# =========================
def ocr_snapshot(md_text: str, plain_text: str):
    """
    Store the OCR output in state (so it can be sent to the chat later)
    and return the quick-preview views.
    """
    return md_text, plain_text, md_text, plain_text


def chat_reply(user_msg, chat_state, ocr_md_state, ocr_txt_state):
    """
    Logic behind the "Send" button in the chat.
    """
    try:
        answer = txagent_chat_remote(
            ocr_md_state or "",
            ocr_txt_state or "",
            user_msg or ""
        )
        updated = (chat_state or []) + [
            {"role": "user", "content": user_msg or "(OCR only)"},
            {"role": "assistant", "content": answer},
        ]
        return updated, "", ""
    except Exception as e:
        tb = traceback.format_exc(limit=2)
        updated = (chat_state or []) + [
            {"role": "user", "content": user_msg or ""},
            {
                "role": "assistant",
                "content": f"⚠️ Remote error (chat): {e.__class__.__name__}: {e}",
            },
        ]
        return updated, "", f"{e}\n{tb}"


def clear_chat():
    return [], "", ""
# =========================
# UI in Gradio 5
# =========================
with gr.Blocks(
    title="OpScanIA",
    theme=gr.themes.Soft()
) as demo:
    gr.Markdown(
        """
        # 📄 DeepSeek-OCR → 💬 Clinical Chat
        1. **Upload an image** and run **OCR** (annotated image, Markdown, and text).
        2. The chat automatically uses the OCR-detected text as clinical context.

        ⚠ Educational use only. Not a substitute for professional medical advice.
        """
    )
    # State used to pass OCR -> Chat
    ocr_md_state = gr.State("")
    ocr_txt_state = gr.State("")

    with gr.Row():
        # OCR panel
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                sources=["upload", "clipboard", "webcam"]
            )
            model_size = gr.Dropdown(
                choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
                value="Gundam (Recommended)",
                label="OCR Model Size"
            )
            task_type = gr.Dropdown(
                choices=["Free OCR", "Convert to Markdown"],
                value="Convert to Markdown",
                label="OCR Task"
            )
            eval_mode_checkbox = gr.Checkbox(
                value=True,
                label="Evaluation mode (faster)",
                info="May skip the annotated image and focus on the text."
            )
            submit_btn = gr.Button("Process Image", variant="primary")

        # OCR results
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Annotated Image"):
                    output_image = gr.Image(interactive=False)
                with gr.TabItem("Markdown Preview"):
                    output_markdown = gr.Markdown()
                with gr.TabItem("Markdown / OCR Text"):
                    output_text = gr.Textbox(
                        lines=18,
                        show_copy_button=True,
                        interactive=False
                    )
            with gr.Row():
                md_preview = gr.Textbox(
                    label="OCR Markdown Snapshot",
                    lines=8,
                    interactive=False
                )
                txt_preview = gr.Textbox(
                    label="OCR Text Snapshot",
                    lines=8,
                    interactive=False
                )
    # Chat panel
    gr.Markdown("## Clinical Chat")
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label="OCR Assistant (remote TxAgent)",
                type="messages",
                height=420,
            )
            user_in = gr.Textbox(
                label="Message",
                placeholder="Type your query… (leave empty to analyze only the OCR)",
                lines=2,
            )
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear")
        with gr.Column(scale=1):
            error_box = gr.Textbox(
                label="Debug (on error)",
                lines=8,
                interactive=False
            )
    # Wiring OCR
    submit_btn.click(
        fn=ocr_infer,
        inputs=[image_input, model_size, task_type, eval_mode_checkbox],
        outputs=[output_image, output_markdown, output_text],
    ).then(
        fn=ocr_snapshot,
        inputs=[output_markdown, output_text],
        outputs=[ocr_md_state, ocr_txt_state, md_preview, txt_preview],
    )

    # Wiring Chat
    send_btn.click(
        fn=chat_reply,
        inputs=[user_in, chatbot, ocr_md_state, ocr_txt_state],
        outputs=[chatbot, user_in, error_box],
    )
    clear_btn.click(
        fn=clear_chat,
        outputs=[chatbot, user_in, error_box],
    )
if __name__ == "__main__":
    # Gradio 5: queue() no longer takes concurrency_count
    # demo.queue(max_size=32)  # optional, if you want to cap the queue
    demo.launch()
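
# To run locally (a sketch; the filename app.py and the bare package list are
# assumptions, not pinned requirements):
#   pip install gradio torch transformers huggingface_hub pillow spaces
#   HF_TOKEN=hf_... python app.py
# Off Hugging Face infrastructure the `spaces` decorator should be a no-op, so
# @spaces.GPU degrades gracefully on a plain local GPU or CPU.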