# app.py - DeepSeek-OCR + DeepSeek-R1 Medical Mini (remote HF or local GGUF) - Gradio 5
import os, tempfile, traceback
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
import spaces
from huggingface_hub import hf_hub_download, InferenceClient
from llama_cpp import Llama

# ===============================================================
# LLM (chat) configuration - DeepSeek-R1 Medical Mini
# - Remote (HF Inference): R1_REMOTE=1 and (optionally) R1_MODEL_ID, HF_TOKEN
# - Local GGUF (CPU/ZeroGPU): R1_REMOTE=0 and GGUF_REPO / GGUF_FILE
# ===============================================================
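# Example settings for the two modes above (illustrative only; the values in angle brackets
# are placeholders), set as Space variables/secrets or exported before launch:
#   Remote HF Inference:  R1_REMOTE=1  R1_MODEL_ID=Mouhib007/DeepSeek-r1-Medical-Mini  HF_TOKEN=<your token>
#   Local GGUF:           R1_REMOTE=0  GGUF_REPO=<gguf repo id>  GGUF_FILE=<file>.gguf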
R1_REMOTE = os.getenv("R1_REMOTE", "0") == "1"
R1_MODEL_ID = os.getenv("R1_MODEL_ID", "Mouhib007/DeepSeek-r1-Medical-Mini")
HF_TOKEN = os.getenv("HF_TOKEN")  # public repo -> may be None

# ---- Local GGUF (fallback / offline mode) ----
GGUF_CANDIDATES = []
ENV_REPO = os.getenv("GGUF_REPO", "").strip()
ENV_FILE = os.getenv("GGUF_FILE", "").strip()
if ENV_REPO and ENV_FILE:
    GGUF_CANDIDATES.append((ENV_REPO, ENV_FILE))
# Default candidate (adjust it if you use a different one)
GGUF_CANDIDATES.append((
    "mradermacher/DeepSeek-r1-Medical-Mini-GGUF",
    "DeepSeek-r1-Medical-Mini.f16.gguf"
))

N_CTX = int(os.getenv("N_CTX", "2048"))
N_THREADS = int(os.getenv("N_THREADS", str(os.cpu_count() or 4)))
N_GPU_LAYERS = int(os.getenv("N_GPU_LAYERS", "0"))
N_BATCH = int(os.getenv("N_BATCH", "96"))
# ---- Remote client (HF Inference) ----
_remote_client = None

def get_remote_client():
    global _remote_client
    if _remote_client is None:
        _remote_client = InferenceClient(model=R1_MODEL_ID, token=HF_TOKEN, timeout=60)
    return _remote_client

# ---- ChatML formatting (compatible with DeepSeek/Qwen) ----
def _format_chatml(messages):
    parts = []
    for m in messages:
        role = m.get("role", "user")
        content = m.get("content", "")
        parts.append(f"<|im_start|>{role}\n{content}<|im_end|>\n")
    parts.append("<|im_start|>assistant\n")
    return "".join(parts)
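# Illustration of the prompt built above (not executed): for a single user turn,
# _format_chatml([{"role": "user", "content": "Hola"}]) returns
#   "<|im_start|>user\nHola<|im_end|>\n<|im_start|>assistant\n"
# i.e. the prompt ends with an open assistant turn, and "<|im_end|>" is used as the stop sequence.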
def r1_chat(messages, temperature=0.2, max_tokens=384):
    """Remote (HF Inference) or local (llama-cpp) generation with DeepSeek-R1 Medical Mini."""
    if R1_REMOTE:
        client = get_remote_client()
        try:
            # Some endpoints support chat_completion directly
            resp = client.chat_completion(messages=messages, temperature=temperature, max_tokens=max_tokens)
            return resp.choices[0].message["content"]
        except Exception:
            # Universal fallback: text_generation with a ChatML prompt
            try:
                prompt = _format_chatml(messages)
                return client.text_generation(
                    prompt,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    stop_sequences=["<|im_end|>"],
                    stream=False,
                )
            except Exception:
                # If the remote call fails (401/429/etc.), fall through to the local GGUF if one is available
                pass
    # Local GGUF
    llm = get_llm()
    out = llm.create_chat_completion(messages=messages, temperature=temperature, max_tokens=max_tokens)
    return out["choices"][0]["message"]["content"]
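# Optional smoke test (an assumption for local debugging, not part of the app; the
# R1_SMOKE_TEST variable is made up here). Kept commented out so importing app.py never runs it:
# if os.getenv("R1_SMOKE_TEST") == "1":
#     print(r1_chat([{"role": "user", "content": "ping"}], max_tokens=16))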
# ---- Local loader (GGUF) ----
_llm = None

def _download_gguf():
    last_err = None
    for repo, fname in GGUF_CANDIDATES:
        try:
            return hf_hub_download(repo_id=repo, filename=fname), repo, fname
        except Exception as e:
            last_err = e
    raise RuntimeError(f"No se pudo descargar ningún GGUF. Último error: {last_err}")

def get_llm():
    global _llm
    if _llm is not None:
        return _llm
    gguf_path, _, _ = _download_gguf()
    _llm = Llama(
        model_path=gguf_path,
        # chat_format is not forced; the template embedded in the R1 GGUF is used
        n_ctx=N_CTX,
        n_threads=N_THREADS,
        n_gpu_layers=N_GPU_LAYERS,
        n_batch=N_BATCH,
        verbose=False,
    )
    return _llm
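# Note on the defaults above: with N_GPU_LAYERS=0 llama.cpp keeps every layer on the CPU,
# which matches a CPU-only Space; raising N_GPU_LAYERS via the env var offloads that many
# layers to a GPU, provided llama-cpp-python was built with GPU support.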
# Optional warmup (so the first message does not have to wait when running locally)
if os.getenv("WARMUP", "0") == "1" and not R1_REMOTE:
    try:
        get_llm()
    except Exception:
        pass

# ===============================================================
# DeepSeek-OCR (unchanged, with a fallback when FlashAttention2 is unavailable)
# ===============================================================
def _best_dtype():
    if torch.cuda.is_available():
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    return torch.float32

def _load_ocr_model():
    model_name = "deepseek-ai/DeepSeek-OCR"
    ocr_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    attn_impl = os.getenv("OCR_ATTN_IMPL", "flash_attention_2")  # same default as before
    try:
        ocr_model = AutoModel.from_pretrained(
            model_name,
            _attn_implementation=attn_impl,
            trust_remote_code=True,
            use_safetensors=True,
        ).eval()
        return ocr_tokenizer, ocr_model
    except Exception as e:
        # If loading fails because of FlashAttention2, retry in "eager" mode (CPU/compatibility)
        msg = str(e)
        if "flash_attn" in msg or "FlashAttention2" in msg or "flash_attention_2" in msg:
            ocr_model = AutoModel.from_pretrained(
                model_name,
                _attn_implementation="eager",
                trust_remote_code=True,
                use_safetensors=True,
            ).eval()
            return ocr_tokenizer, ocr_model
        raise

tokenizer, model = _load_ocr_model()
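# On machines that will never have FlashAttention2 installed, setting OCR_ATTN_IMPL=eager
# up front skips the failed first load attempt entirely; the fallback above only fires when
# the raised error message mentions flash attention.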
@spaces.GPU  # assumed fix: the `spaces` import above is otherwise unused; on ZeroGPU Spaces this decorator allocates the GPU for the call and should be a no-op elsewhere
def process_image(image, model_size, task_type, is_eval_mode):
    """
    Returns: annotated image, markdown, and plain text (or the markdown when no plain text was produced).
    """
    if image is None:
        return None, "Please upload an image first.", "Please upload an image first."
    dtype = _best_dtype()
    model_device = model.cuda().to(dtype) if torch.cuda.is_available() else model.to(dtype)
    with tempfile.TemporaryDirectory() as output_path:
        if task_type == "Free OCR":
            prompt = "<image>\nFree OCR. "
        elif task_type == "Convert to Markdown":
            prompt = "<image>\n<|grounding|>Convert the document to markdown. "
        else:
            prompt = "<image>\nFree OCR. "
        temp_image_path = os.path.join(output_path, "temp_image.jpg")
        image.convert("RGB").save(temp_image_path)  # convert first so RGBA/PNG uploads can be written as JPEG
        size_configs = {
            "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
            "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
            "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
            "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
            "Gundam (Recommended)": {"base_size": 1024, "image_size": 640, "crop_mode": True},
        }
        config = size_configs.get(model_size, size_configs["Gundam (Recommended)"])
        plain_text_result = model_device.infer(
            tokenizer,
            prompt=prompt,
            image_file=temp_image_path,
            output_path=output_path,
            base_size=config["base_size"],
            image_size=config["image_size"],
            crop_mode=config["crop_mode"],
            save_results=True,
            test_compress=True,
            eval_mode=is_eval_mode,
        )
        image_result_path = os.path.join(output_path, "result_with_boxes.jpg")
        markdown_result_path = os.path.join(output_path, "result.mmd")
        if os.path.exists(markdown_result_path):
            with open(markdown_result_path, "r", encoding="utf-8") as f:
                markdown_content = f.read()
        else:
            markdown_content = "Markdown result was not generated. This is expected for the 'Free OCR' task."
        result_image = None
        if os.path.exists(image_result_path):
            result_image = Image.open(image_result_path)
            result_image.load()
        text_result = plain_text_result if plain_text_result else markdown_content
        return result_image, markdown_content, text_result
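# Usage sketch (illustrative; "sample.jpg" is a placeholder) mirroring what the Gradio callback below does:
#   annotated, md, text = process_image(Image.open("sample.jpg"), "Gundam (Recommended)", "Convert to Markdown", False)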
# ===============================================================
# Chat (injects the OCR output into the first system message), answered by R1
# ===============================================================
def _truncate(text, max_chars=3000):
    return (text or "")[:max_chars]

def _system_prompt():
    return (
        "Eres un asistente clínico educativo. No sustituyes el juicio médico. "
        "Usa CONTEXTO_OCR si existe; si falta, pídelo. Evita diagnósticos definitivos."
    )

def _ocr_context(ocr_md, ocr_txt):
    return _truncate(ocr_md) or _truncate(ocr_txt) or ""

def to_chat_messages(chat_msgs, ocr_md, ocr_txt):
    sys = _system_prompt()
    ctx = _ocr_context(ocr_md, ocr_txt)
    if ctx:
        sys += (
            "\n\n---\n"
            "CONTEXTO_OCR (fuente principal; si falta un dato, dilo explícitamente):\n"
            f"{ctx}\n---"
        )
    msgs = [{"role": "system", "content": sys}]
    for m in (chat_msgs or []):
        if m.get("role") in ("user", "assistant"):
            msgs.append({"role": m["role"], "content": m.get("content", "")})
    return msgs
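# Shape illustration (not executed): with OCR context available, to_chat_messages([], md, txt)
# returns a single system message whose content is the system prompt plus the truncated
# CONTEXTO_OCR block; prior user/assistant turns are appended after it and any other roles are dropped.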
def r1_reply(user_msg, chat_msgs, ocr_md, ocr_txt):
    if not user_msg:
        user_msg = "Analiza el CONTEXTO_OCR anterior y responde a partir de ese contenido."
    try:
        msgs = to_chat_messages(chat_msgs, ocr_md, ocr_txt) + [{"role": "user", "content": user_msg}]
        answer = r1_chat(msgs, temperature=0.2, max_tokens=512)
        updated = (chat_msgs or []) + [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": answer},
        ]
        return updated, "", gr.update(value="")
    except Exception as e:
        err = f"{e.__class__.__name__}: {str(e) or repr(e)}"
        tb = traceback.format_exc(limit=2)
        updated = (chat_msgs or []) + [
            {"role": "user", "content": user_msg or ""},
            {"role": "assistant", "content": f"⚠️ Error LLM: {err}"},
        ]
        return updated, "", gr.update(value=f"{err}\n{tb}")

def clear_chat():
    return [], "", gr.update(value="")
# ===============================================================
# UI (Gradio 5)
# ===============================================================
with gr.Blocks(title="DeepSeek-OCR + DeepSeek-R1 Medical Mini", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# DeepSeek-OCR → Chat Médico con **DeepSeek-R1 Medical Mini** (remoto HF o local GGUF)
1) **Sube una imagen** y corre **OCR** (imagen anotada, Markdown y texto).
2) **Chatea** con **DeepSeek-R1 Medical Mini** usando automáticamente el **OCR** como contexto.
*Uso educativo; no reemplaza consejo médico.*
"""
    )
    ocr_md_state = gr.State("")
    ocr_txt_state = gr.State("")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image", sources=["upload", "clipboard", "webcam"])
            model_size = gr.Dropdown(
                choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
                value="Gundam (Recommended)", label="Model Size",
            )
            task_type = gr.Dropdown(
                choices=["Free OCR", "Convert to Markdown"],
                value="Convert to Markdown", label="Task Type",
            )
            eval_mode_checkbox = gr.Checkbox(
                value=False, label="Enable Evaluation Mode",
                info="Solo texto (más rápido). Desmárcalo para ver imagen anotada y markdown.",
            )
            submit_btn = gr.Button("Process Image", variant="primary")
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Annotated Image"): output_image = gr.Image(interactive=False)
                with gr.TabItem("Markdown Preview"): output_markdown = gr.Markdown()
                with gr.TabItem("Markdown Source (or Eval Output)"):
                    output_text = gr.Textbox(lines=18, show_copy_button=True, interactive=False)

    with gr.Row():
        md_preview = gr.Textbox(label="Snapshot Markdown OCR", lines=10, interactive=False)
        txt_preview = gr.Textbox(label="Snapshot Texto OCR", lines=10, interactive=False)

    gr.Markdown("## Chat Clínico (DeepSeek-R1 Medical Mini)")
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Asistente OCR (R1 Medical Mini)", type="messages", height=420)
            user_in = gr.Textbox(label="Mensaje", placeholder="Escribe tu consulta… (vacío = analiza solo el OCR)", lines=2)
            with gr.Row():
                send_btn = gr.Button("Enviar", variant="primary")
                clear_btn = gr.Button("Limpiar")
        with gr.Column(scale=1):
            error_box = gr.Textbox(label="Debug (si hay error)", lines=8, interactive=False)

    # OCR → outputs and states
    submit_btn.click(
        fn=process_image,
        inputs=[image_input, model_size, task_type, eval_mode_checkbox],
        outputs=[output_image, output_markdown, output_text],
    ).then(
        fn=lambda md, tx: (md, tx, md, tx),
        inputs=[output_markdown, output_text],
        outputs=[ocr_md_state, ocr_txt_state, md_preview, txt_preview],
    )

    # Chat
    send_btn.click(
        fn=r1_reply,
        inputs=[user_in, chatbot, ocr_md_state, ocr_txt_state],
        outputs=[chatbot, user_in, error_box],
    )
    clear_btn.click(fn=clear_chat, outputs=[chatbot, user_in, error_box])

if __name__ == "__main__":
    demo.queue(max_size=20)
    demo.launch()