jorgeiv500 committed · Commit 7cb8c04 · verified · Parent(s): 1cb9d27

Update app.py

Files changed (1):
  1. app.py +144 -268
app.py CHANGED
@@ -1,9 +1,9 @@
- # app.py — DeepSeek-OCR + BioMedLM (remote text_generation + ZeroGPU-safe local) — Gradio 5
- import os, tempfile, traceback, json
  import gradio as gr
  import torch
  from PIL import Image
- from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
  import spaces
  from huggingface_hub import InferenceClient
  import requests
@@ -11,119 +11,95 @@ import requests
  # =========================
  # CONFIG (env)
  # =========================
- BIO_REMOTE = os.getenv("BIO_REMOTE", "1") == "1"  # recommended on Spaces ZeroGPU
- BIO_MODEL_ID = os.getenv("BIO_MODEL_ID", "stanford-crfm/BioMedLM").strip()
  HF_TOKEN = os.getenv("HF_TOKEN")

- # Fallbacks
- BIO_FALLBACK_HTTP = os.getenv("BIO_FALLBACK_HTTP", "1") == "1"    # if InferenceClient fails => HTTP router
- BIO_FALLBACK_LOCAL = os.getenv("BIO_FALLBACK_LOCAL", "1") == "1"  # if every remote path fails => try local GPU
-
- # Generation parameters
- GEN_TEMPERATURE = float(os.getenv("GEN_TEMPERATURE", "0.2"))
- GEN_TOP_P = float(os.getenv("GEN_TOP_P", "0.9"))
- GEN_MAX_NEW_TOKENS = int(os.getenv("GEN_MAX_NEW_TOKENS", "512"))
- GEN_REP_PENALTY = float(os.getenv("GEN_REP_PENALTY", "1.1"))
  GEN_TIMEOUT = int(os.getenv("GEN_TIMEOUT", "60"))  # s

- STOP_SEQS = ["\nUser:", "### System", "### Context", "### Conversation"]
-
- # Caches (do not touch CUDA in the main process)
- _hf_client = None
- _bio_local_cache = {"model": None, "tokenizer": None}

  # =========================
  # Prompt helpers
  # =========================
- def _truncate(text, max_chars=3000): return (text or "")[:max_chars]
-
- def _system_prompt():
-     return ("Eres un asistente clínico educativo. No sustituyes el juicio médico. "
-             "Usa CONTEXTO_OCR si existe; si falta, pídelo. Evita diagnósticos definitivos.")
-
- def _ocr_context(ocr_md, ocr_txt): return _truncate(ocr_md) or _truncate(ocr_txt) or ""

  def build_prompt(chat_msgs, ocr_md, ocr_txt, user_msg):
-     sys = _system_prompt()
-     ctx = _ocr_context(ocr_md, ocr_txt)

-     history_lines = []
      for m in (chat_msgs or []):
-         role = m.get("role")
-         content = (m.get("content") or "").strip()
-         if not content:
-             continue
-         if role == "user":
-             history_lines.append(f"User: {content}")
-         elif role == "assistant":
-             history_lines.append(f"Assistant: {content}")
-
-     if user_msg:
-         history_lines.append(f"User: {user_msg}")
-
-     convo = "\n".join(history_lines).strip()
-     prompt = f"### System\n{sys}\n\n"
-     if ctx:
-         prompt += f"### Context (OCR)\n{ctx}\n\n"
-     prompt += f"### Conversation\n{convo}\nAssistant:"
      return prompt

  # =========================
- # BioMedLM remote/local (NO CUDA in main)
  # =========================
- def get_biomedlm():
-     """Decide the mode. Do not touch CUDA here."""
-     global _hf_client
-     if BIO_REMOTE:
-         if _hf_client is None:
-             # timeout goes in the constructor (not in the call)
-             _hf_client = InferenceClient(
-                 model=BIO_MODEL_ID,
-                 token=HF_TOKEN,
-                 timeout=GEN_TIMEOUT,
-             )
-         return ("remote", _hf_client)
-     return ("local", None)
-
- def _hf_http_completions(prompt: str) -> str:
-     """HTTP fallback to the HF router (OpenAI-like /v1/completions)."""
-     headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
-     payload = {
-         "model": BIO_MODEL_ID,
-         "prompt": prompt,
-         "max_tokens": GEN_MAX_NEW_TOKENS,
-         "temperature": GEN_TEMPERATURE,
-         "top_p": GEN_TOP_P,
-         "stop": STOP_SEQS,
-     }
-     urls = [
-         "https://router.huggingface.co/v1/completions",
-         "https://router.huggingface.co/hf-inference/v1/completions",
-     ]
-     last_exc = None
-     for url in urls:
-         try:
-             r = requests.post(url, headers=headers, json=payload, timeout=GEN_TIMEOUT)
-             if r.status_code == 200:
-                 data = r.json()
-                 # OpenAI completions-like
-                 if isinstance(data, dict) and "choices" in data and data["choices"]:
-                     return (data["choices"][0].get("text") or "").strip()
-                 return json.dumps(data)[:4000]
-             last_exc = RuntimeError(f"HTTP {r.status_code}: {r.text[:800]}")
-         except Exception as e:
-             last_exc = e
-     raise last_exc or RuntimeError("HF router completions error")
-
- def call_biomedlm_remote(prompt: str) -> (str, str):
      """
-     Uses InferenceClient.text_generation (a task BioMedLM supports).
-     If that fails, falls back to the HTTP router /v1/completions.
-     Returns (answer, debug_msg).
      """
-     client = get_biomedlm()[1]
      try:
-         out = client.text_generation(
              prompt=prompt,
              max_new_tokens=GEN_MAX_NEW_TOKENS,
              temperature=GEN_TEMPERATURE,
@@ -131,151 +107,76 @@ def call_biomedlm_remote(prompt: str) -> (str, str):
              repetition_penalty=GEN_REP_PENALTY,
              stop_sequences=STOP_SEQS,
              details=False,
              stream=False,
          )
-         # huggingface_hub returns a str when details=False
-         answer = out.strip() if isinstance(out, str) else str(out)
-         return answer, ""
-     except Exception as e:
-         if not BIO_FALLBACK_HTTP:
-             raise
-         # HTTP fallback to the new router (completions)
          try:
-             answer = _hf_http_completions(prompt)
-             return answer, f"[Fallback HTTP router/completions] {e.__class__.__name__}: {e}"
          except Exception as e2:
-             raise RuntimeError(
-                 f"Remote generation failed: {e.__class__.__name__}: {e} | HTTP fallback: {e2.__class__.__name__}: {e2}"
-             )
-
- @spaces.GPU
- def biomedlm_infer_local(prompt: str,
-                          temperature=0.2,
-                          top_p=0.9,
-                          rep_penalty=1.1,
-                          max_new_tokens=512) -> str:
-     """Local run on the GPU worker; returns OK:: or ERR::..."""
-     try:
-         if _bio_local_cache["model"] is None:
-             tok = AutoTokenizer.from_pretrained(BIO_MODEL_ID, use_fast=True)
-             dtype = torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) else (
-                 torch.float16 if torch.cuda.is_available() else torch.float32
-             )
-             model = AutoModelForCausalLM.from_pretrained(BIO_MODEL_ID, torch_dtype=dtype)
-             if torch.cuda.is_available():
-                 model = model.to("cuda")
-             _bio_local_cache["model"] = model.eval()
-             _bio_local_cache["tokenizer"] = tok
-
-         model = _bio_local_cache["model"]
-         tok = _bio_local_cache["tokenizer"]
-         inputs = tok(prompt, return_tensors="pt")
-         if torch.cuda.is_available():
-             inputs = {k: v.to("cuda") for k, v in inputs.items()}
-
-         gen_ids = model.generate(
-             **inputs,
-             do_sample=True,
-             temperature=temperature,
-             top_p=top_p,
-             repetition_penalty=rep_penalty,
-             max_new_tokens=max_new_tokens,
-             eos_token_id=tok.eos_token_id,
-         )
-         text = tok.decode(gen_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
-         return "OK::" + text.strip()
-     except Exception as e:
-         return f"ERR::[{e.__class__.__name__}] {str(e) or repr(e)}"

- def biomedlm_reply(user_msg, chat_msgs, ocr_md, ocr_txt):
      try:
-         if not user_msg:
-             user_msg = "Analiza el CONTEXTO_OCR anterior y responde a partir de ese contenido."
          prompt = build_prompt(chat_msgs, ocr_md, ocr_txt, user_msg)
-
-         mode, _ = get_biomedlm()
-         if mode == "remote":
-             try:
-                 answer, dbg = call_biomedlm_remote(prompt)
-                 updated = (chat_msgs or []) + [
-                     {"role": "user", "content": user_msg},
-                     {"role": "assistant", "content": answer}
-                 ]
-                 return updated, "", gr.update(value=dbg)
-             except Exception as e_remote:
-                 if not BIO_FALLBACK_LOCAL:
-                     raise
-                 # Fall back to local if remote is unavailable
-                 res = biomedlm_infer_local(
-                     prompt,
-                     temperature=GEN_TEMPERATURE,
-                     top_p=GEN_TOP_P,
-                     rep_penalty=GEN_REP_PENALTY,
-                     max_new_tokens=GEN_MAX_NEW_TOKENS
-                 )
-                 if res.startswith("OK::"):
-                     answer = res[4:]
-                     updated = (chat_msgs or []) + [
-                         {"role": "user", "content": user_msg},
-                         {"role": "assistant", "content": answer}
-                     ]
-                     return updated, "", gr.update(value=f"[Remoto→Local] {e_remote}")
-                 else:
-                     err_msg = res[5:] if res.startswith("ERR::") else res
-                     raise RuntimeError(f"Remote error: {e_remote} | Local error: {err_msg}")
-
-         # Explicit local mode
-         res = biomedlm_infer_local(
-             prompt,
-             temperature=GEN_TEMPERATURE,
-             top_p=GEN_TOP_P,
-             rep_penalty=GEN_REP_PENALTY,
-             max_new_tokens=GEN_MAX_NEW_TOKENS
-         )
-         if res.startswith("OK::"):
-             answer = res[4:]
-             updated = (chat_msgs or []) + [
-                 {"role": "user", "content": user_msg},
-                 {"role": "assistant", "content": answer}
-             ]
-             return updated, "", gr.update(value="")
-         else:
-             err_msg = res[5:] if res.startswith("ERR::") else res
-             updated = (chat_msgs or []) + [
-                 {"role": "user", "content": user_msg},
-                 {"role": "assistant", "content": "⚠️ Error LLM (local). Revisa el panel de debug."}
-             ]
-             return updated, "", gr.update(value=err_msg)
-
      except Exception as e:
-         err = f"{e.__class__.__name__}: {str(e) or repr(e)}"
          tb = traceback.format_exc(limit=2)
          updated = (chat_msgs or []) + [
              {"role": "user", "content": user_msg or ""},
-             {"role": "assistant", "content": f"⚠️ Error LLM: {err}"}
          ]
-         return updated, "", gr.update(value=f"{err}\n{tb}")

  def clear_chat(): return [], "", gr.update(value="")

  # =========================
- # DeepSeek-OCR (no CUDA in main)
  # =========================
  def _load_ocr_model():
      model_name = "deepseek-ai/DeepSeek-OCR"
-     ocr_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
      attn_impl = os.getenv("OCR_ATTN_IMPL", "flash_attention_2")
      try:
-         ocr_model = AutoModel.from_pretrained(
-             model_name, _attn_implementation=attn_impl, trust_remote_code=True, use_safetensors=True
          ).eval()
-         return ocr_tokenizer, ocr_model
      except Exception as e:
          if any(k in str(e).lower() for k in ["flash_attn", "flashattention2", "flash_attention_2"]):
-             ocr_model = AutoModel.from_pretrained(
-                 model_name, _attn_implementation="eager", trust_remote_code=True, use_safetensors=True
              ).eval()
-             return ocr_tokenizer, ocr_model
          raise

  tokenizer, model = _load_ocr_model()
@@ -285,6 +186,7 @@ def process_image(image, model_size, task_type, is_eval_mode):
      if image is None:
          return None, "Please upload an image first.", "Please upload an image first."

      if torch.cuda.is_available():
          dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
          model_device = model.to(dtype).to("cuda")
@@ -297,23 +199,23 @@ def process_image(image, model_size, task_type, is_eval_mode):
      temp_image_path = os.path.join(output_path, "temp_image.jpg")
      image.save(temp_image_path)

-     size_configs = {
-         "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
-         "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
-         "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
-         "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
-         "Gundam (Recommended)": {"base_size": 1024, "image_size": 640, "crop_mode": True},
      }
-     config = size_configs.get(model_size, size_configs["Gundam (Recommended)"])

-     plain_text_result = model_device.infer(
          tokenizer,
          prompt=prompt,
          image_file=temp_image_path,
          output_path=output_path,
-         base_size=config["base_size"],
-         image_size=config["image_size"],
-         crop_mode=config["crop_mode"],
          save_results=True,
          test_compress=True,
          eval_mode=is_eval_mode,
@@ -331,18 +233,18 @@ def process_image(image, model_size, task_type, is_eval_mode):
      if os.path.exists(image_result_path):
          result_image = Image.open(image_result_path); result_image.load()

-     text_result = plain_text_result if plain_text_result else markdown_content
      return result_image, markdown_content, text_result

  # =========================
  # UI (Gradio 5)
  # =========================
- with gr.Blocks(title="DeepSeek-OCR + BioMedLM", theme=gr.themes.Soft()) as demo:
      gr.Markdown(
          """
- # DeepSeek-OCR → Chat Médico con **BioMedLM**
  1) **Sube una imagen** y corre **OCR** (imagen anotada, Markdown y texto).
- 2) **Chatea** con **BioMedLM** usando automáticamente el **OCR** como contexto.
  *Uso educativo; no reemplaza consejo médico.*
          """
      )
@@ -353,10 +255,14 @@ with gr.Blocks(title="DeepSeek-OCR + BioMedLM", theme=gr.themes.Soft()) as demo:
      with gr.Row():
          with gr.Column(scale=1):
              image_input = gr.Image(type="pil", label="Upload Image", sources=["upload", "clipboard", "webcam"])
-             model_size = gr.Dropdown(choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
-                                      value="Gundam (Recommended)", label="Model Size")
-             task_type = gr.Dropdown(choices=["Free OCR", "Convert to Markdown"],
-                                     value="Convert to Markdown", label="Task Type")
              eval_mode_checkbox = gr.Checkbox(value=False, label="Enable Evaluation Mode",
                                               info="Solo texto (más rápido). Desmárcalo para ver imagen anotada y markdown.")
              submit_btn = gr.Button("Process Image", variant="primary")
@@ -365,37 +271,7 @@ with gr.Blocks(title="DeepSeek-OCR + BioMedLM", theme=gr.themes.Soft()) as demo:
      with gr.Tabs():
          with gr.TabItem("Annotated Image"): output_image = gr.Image(interactive=False)
          with gr.TabItem("Markdown Preview"): output_markdown = gr.Markdown()
-         with gr.TabItem("Markdown Source (or Eval Output)"):
-             output_text = gr.Textbox(lines=18, show_copy_button=True, interactive=False)
      with gr.Row():
          md_preview = gr.Textbox(label="Snapshot Markdown OCR", lines=10, interactive=False)
-         txt_preview = gr.Textbox(label="Snapshot Texto OCR", lines=10, interactive=False)
-
-     gr.Markdown("## Chat Clínico (BioMedLM)")
-     with gr.Row():
-         with gr.Column(scale=2):
-             chatbot = gr.Chatbot(label="Asistente OCR (BioMedLM)", type="messages", height=420)
-             user_in = gr.Textbox(label="Mensaje", placeholder="Escribe tu consulta… (vacío = analiza solo el OCR)", lines=2)
-             with gr.Row():
-                 send_btn = gr.Button("Enviar", variant="primary")
-                 clear_btn = gr.Button("Limpiar")
-         with gr.Column(scale=1):
-             error_box = gr.Textbox(label="Debug (si hay error)", lines=8, interactive=False)
-
-     submit_btn.click(
-         fn=process_image,
-         inputs=[image_input, model_size, task_type, eval_mode_checkbox],
-         outputs=[output_image, output_markdown, output_text],
-     ).then(
-         fn=lambda md, tx: (md, tx, md, tx),
-         inputs=[output_markdown, output_text],
-         outputs=[ocr_md_state, ocr_txt_state, md_preview, txt_preview],
-     )
-
-     send_btn.click(fn=biomedlm_reply, inputs=[user_in, chatbot, ocr_md_state, ocr_txt_state],
-                    outputs=[chatbot, user_in, error_box])
-     clear_btn.click(fn=clear_chat, outputs=[chatbot, user_in, error_box])
-
- if __name__ == "__main__":
-     demo.queue(max_size=20)
-     demo.launch()
 
+ # app.py — DeepSeek-OCR + Med42 Instruct (remote, ZeroGPU-safe) — Gradio 5
+ import os, re, json, tempfile, traceback
  import gradio as gr
  import torch
  from PIL import Image
+ from transformers import AutoModel, AutoTokenizer
  import spaces
  from huggingface_hub import InferenceClient
  import requests
 
  # =========================
  # CONFIG (env)
  # =========================
+ LLM_MODEL_ID = os.getenv("BIO_MODEL_ID", "m42-health/Llama3-Med42-8B-Instruct").strip()
  HF_TOKEN = os.getenv("HF_TOKEN")

+ # Generation (deterministic, for better instruction-following)
+ GEN_TEMPERATURE = float(os.getenv("GEN_TEMPERATURE", "0.0"))
+ GEN_TOP_P = float(os.getenv("GEN_TOP_P", "1.0"))
+ GEN_MAX_NEW_TOKENS = int(os.getenv("GEN_MAX_NEW_TOKENS", "384"))
+ GEN_REP_PENALTY = float(os.getenv("GEN_REP_PENALTY", "1.0"))
  GEN_TIMEOUT = int(os.getenv("GEN_TIMEOUT", "60"))  # s
+ STOP_SEQS = ["\n###", "\nUser:", "\nAssistant:"]

+ # Remote (HTTP) client; does not touch CUDA
+ _hf_client = InferenceClient(model=LLM_MODEL_ID, token=HF_TOKEN, timeout=GEN_TIMEOUT)

  # =========================
  # Prompt helpers
  # =========================
+ def _truncate(s: str, n=3000): return (s or "")[:n]
+
+ def _clean_ocr(s: str) -> str:
+     if not s: return ""
+     s = re.sub(r'[^\S\r\n]+', ' ', s)            # collapse runs of whitespace (keep newlines)
+     s = re.sub(r'(\{#Sec\d+\}|#+\w*)', ' ', s)   # strip stray anchors/odd headers
+     s = re.sub(r'\s{2,}', ' ', s)
+     lines = []
+     for par in s.splitlines():
+         par = par.strip()
+         if 0 < len(par) <= 600:
+             lines.append(par)
+     return "\n".join(lines)
+
+ FEWSHOT = """
+ ### INSTRUCCIÓN
+ Eres un **analista clínico educativo**. Responde **SIEMPRE en español**.
+ Reglas: (1) Usa ÚNICAMENTE el CONTEXTO_OCR; (2) Si falta un dato, escribe literalmente: "dato no disponible en el OCR";
+ (3) No inventes nada; (4) Responde en viñetas claras; (5) Cita fragmentos exactos del OCR entre comillas como evidencia.
+
+ ### EJEMPLO 1
+ CONTEXTO_OCR:
+ Paciente: Juan Pérez. Medicamento: Amoxicilina 500 mg cada 8 horas por 7 días.
+ PREGUNTA:
+ ¿Cuál es el medicamento y la dosis?
+ SALIDA_ES:
+ - Medicamento: **Amoxicilina**
+ - Dosis: **500 mg cada 8 horas por 7 días**
+ - Evidencia OCR: "Amoxicilina 500 mg cada 8 horas por 7 días"
+
+ ### EJEMPLO 2
+ CONTEXTO_OCR:
+ Paciente: —. Indicaciones ilegibles.
+ PREGUNTA:
+ ¿Hay contraindicaciones registradas?
+ SALIDA_ES:
+ - Contraindicaciones: **dato no disponible en el OCR**
+ - Evidencia OCR: "Indicaciones ilegibles"
+ """.strip()

  def build_prompt(chat_msgs, ocr_md, ocr_txt, user_msg):
+     raw = ocr_md if (ocr_md and ocr_md.strip()) else ocr_txt
+     ctx = _truncate(_clean_ocr(raw), 3000)

+     history = []
      for m in (chat_msgs or []):
+         role, content = m.get("role"), (m.get("content") or "").strip()
+         if not content: continue
+         history.append(f"- { 'Usuario' if role=='user' else 'Asistente' }: {content}")
+     hist_block = "\n".join(history) if history else "—"
+
+     question = (user_msg or "Analiza el CONTEXTO_OCR y resume lo clínicamente relevante en viñetas.").strip()
+
+     prompt = (
+         FEWSHOT + "\n\n"
+         "### CONTEXTO_OCR\n" + (ctx if ctx else "—") + "\n\n"
+         "### HISTORIAL (si existe)\n" + hist_block + "\n\n"
+         "### PREGUNTA\n" + question + "\n\n"
+         "### SALIDA_ES\n"
+     )
      return prompt

  # =========================
+ # Remote LLM (Med42 Instruct) via text_generation
  # =========================
+ def med42_remote_generate(prompt: str) -> (str, str):
      """
+     Tries InferenceClient.text_generation (serverless/TGI). If it fails,
+     falls back to the OpenAI-like router /v1/completions.
      """
      try:
+         out = _hf_client.text_generation(
              prompt=prompt,
              max_new_tokens=GEN_MAX_NEW_TOKENS,
              temperature=GEN_TEMPERATURE,

              repetition_penalty=GEN_REP_PENALTY,
              stop_sequences=STOP_SEQS,
              details=False,
+             do_sample=False,  # deterministic
              stream=False,
          )
+         return (out.strip() if isinstance(out, str) else str(out)), ""
+     except Exception as e1:
+         # HTTP fallback to the router
          try:
+             headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
+             payload = {
+                 "model": LLM_MODEL_ID,
+                 "prompt": prompt,
+                 "max_tokens": GEN_MAX_NEW_TOKENS,
+                 "temperature": GEN_TEMPERATURE,
+                 "top_p": GEN_TOP_P,
+                 "stop": STOP_SEQS,
+             }
+             for url in ["https://router.huggingface.co/v1/completions",
+                         "https://router.huggingface.co/hf-inference/v1/completions"]:
+                 r = requests.post(url, headers=headers, json=payload, timeout=GEN_TIMEOUT)
+                 if r.status_code == 200:
+                     data = r.json()
+                     if isinstance(data, dict) and "choices" in data and data["choices"]:
+                         return (data["choices"][0].get("text") or "").strip(), f"[Fallback router: {url}] {e1}"
+                 raise RuntimeError(f"HTTP {r.status_code}: {r.text[:800]}")
          except Exception as e2:
+             raise RuntimeError(f"Remote generation failed: {e1.__class__.__name__}: {e1} | HTTP fallback: {e2.__class__.__name__}: {e2}")

+ def med42_reply(user_msg, chat_msgs, ocr_md, ocr_txt):
      try:
          prompt = build_prompt(chat_msgs, ocr_md, ocr_txt, user_msg)
+         answer, dbg = med42_remote_generate(prompt)
+         updated = (chat_msgs or []) + [
+             {"role": "user", "content": user_msg or "(analizar solo OCR)"},
+             {"role": "assistant", "content": answer}
+         ]
+         return updated, "", gr.update(value=dbg)
      except Exception as e:
          tb = traceback.format_exc(limit=2)
          updated = (chat_msgs or []) + [
              {"role": "user", "content": user_msg or ""},
+             {"role": "assistant", "content": f"⚠️ Error LLM: {e}"}
          ]
+         return updated, "", gr.update(value=f"{e}\n{tb}")

  def clear_chat(): return [], "", gr.update(value="")

  # =========================
+ # DeepSeek-OCR (no CUDA in main; GPU only inside the worker)
  # =========================
  def _load_ocr_model():
      model_name = "deepseek-ai/DeepSeek-OCR"
+     tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
      attn_impl = os.getenv("OCR_ATTN_IMPL", "flash_attention_2")
      try:
+         mdl = AutoModel.from_pretrained(
+             model_name,
+             _attn_implementation=attn_impl,
+             trust_remote_code=True,
+             use_safetensors=True
          ).eval()
+         return tok, mdl
      except Exception as e:
          if any(k in str(e).lower() for k in ["flash_attn", "flashattention2", "flash_attention_2"]):
+             mdl = AutoModel.from_pretrained(
+                 model_name,
+                 _attn_implementation="eager",
+                 trust_remote_code=True,
+                 use_safetensors=True
              ).eval()
+             return tok, mdl
          raise

  tokenizer, model = _load_ocr_model()
 
      if image is None:
          return None, "Please upload an image first.", "Please upload an image first."

+     # move to GPU ONLY inside the worker
      if torch.cuda.is_available():
          dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
          model_device = model.to(dtype).to("cuda")

      temp_image_path = os.path.join(output_path, "temp_image.jpg")
      image.save(temp_image_path)

+     size_cfg = {
+         "Tiny": (512, 512, False),
+         "Small": (640, 640, False),
+         "Base": (1024, 1024, False),
+         "Large": (1280, 1280, False),
+         "Gundam (Recommended)": (1024, 640, True),
      }
+     base_size, image_size, crop_mode = size_cfg.get(model_size, (1024, 640, True))

+     plain_text = model_device.infer(
          tokenizer,
          prompt=prompt,
          image_file=temp_image_path,
          output_path=output_path,
+         base_size=base_size,
+         image_size=image_size,
+         crop_mode=crop_mode,
          save_results=True,
          test_compress=True,
          eval_mode=is_eval_mode,

      if os.path.exists(image_result_path):
          result_image = Image.open(image_result_path); result_image.load()

+     text_result = plain_text if plain_text else markdown_content
      return result_image, markdown_content, text_result

  # =========================
  # UI (Gradio 5)
  # =========================
+ with gr.Blocks(title="DeepSeek-OCR + Med42 Instruct", theme=gr.themes.Soft()) as demo:
      gr.Markdown(
          """
+ # DeepSeek-OCR → Chat Clínico con **Med42 Instruct**
  1) **Sube una imagen** y corre **OCR** (imagen anotada, Markdown y texto).
+ 2) **Chatea** con **Med42** usando automáticamente el **OCR** como contexto.
  *Uso educativo; no reemplaza consejo médico.*
          """
      )
 
      with gr.Row():
          with gr.Column(scale=1):
              image_input = gr.Image(type="pil", label="Upload Image", sources=["upload", "clipboard", "webcam"])
+             model_size = gr.Dropdown(
+                 choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
+                 value="Gundam (Recommended)", label="Model Size"
+             )
+             task_type = gr.Dropdown(
+                 choices=["Free OCR", "Convert to Markdown"],
+                 value="Convert to Markdown", label="Task Type"
+             )
              eval_mode_checkbox = gr.Checkbox(value=False, label="Enable Evaluation Mode",
                                               info="Solo texto (más rápido). Desmárcalo para ver imagen anotada y markdown.")
              submit_btn = gr.Button("Process Image", variant="primary")

      with gr.Tabs():
          with gr.TabItem("Annotated Image"): output_image = gr.Image(interactive=False)
          with gr.TabItem("Markdown Preview"): output_markdown = gr.Markdown()
+         with gr.TabItem("Markdown Source / Eval"): output_text = gr.Textbox(lines=18, show_copy_button=True, interactive=False)
      with gr.Row():
          md_preview = gr.Textbox(label="Snapshot Markdown OCR", lines=10, interactive=False)
+         txt_preview = gr.Textbox_