jorgeiv500 committed
Commit 4a2190b · verified · 1 Parent(s): 6bee325

Update app.py

Files changed (1)
  app.py +100 -151
app.py CHANGED
@@ -1,4 +1,4 @@
-# app.py — DeepSeek-OCR + BioMedLM with fixes for StopIteration (HF) and ZeroGPU — Gradio 5
+# app.py — DeepSeek-OCR + BioMedLM (HF router fix + ZeroGPU-safe) — Gradio 5
 import os, tempfile, traceback, json
 import gradio as gr
 import torch
@@ -6,44 +6,40 @@ from PIL import Image
 from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
 import spaces
 from huggingface_hub import InferenceClient
-import requests  # direct HTTP fallback to HF if InferenceClient fails
+import requests

-# ===============================================================
+# =========================
 # CONFIG (env)
-# ===============================================================
-BIO_REMOTE = os.getenv("BIO_REMOTE", "1") == "1"  # Recommended on Spaces ZeroGPU
+# =========================
+BIO_REMOTE = os.getenv("BIO_REMOTE", "1") == "1"  # recommended on Spaces ZeroGPU
 BIO_MODEL_ID = os.getenv("BIO_MODEL_ID", "stanford-crfm/BioMedLM").strip()
 HF_TOKEN = os.getenv("HF_TOKEN")
-HF_PROVIDER = os.getenv("HF_PROVIDER", "hf-inference").strip()  # force the provider and avoid StopIteration
-BIO_FALLBACK_REMOTE = os.getenv("BIO_FALLBACK_REMOTE", "1") == "1"  # if local fails => try remote
+HF_PROVIDER = os.getenv("HF_PROVIDER", "hf-inference").strip()

 GEN_TEMPERATURE = float(os.getenv("GEN_TEMPERATURE", "0.2"))
 GEN_TOP_P = float(os.getenv("GEN_TOP_P", "0.9"))
 GEN_MAX_NEW_TOKENS = int(os.getenv("GEN_MAX_NEW_TOKENS", "512"))
 GEN_REP_PENALTY = float(os.getenv("GEN_REP_PENALTY", "1.1"))
-GEN_TIMEOUT = int(os.getenv("GEN_TIMEOUT", "60"))  # seconds, for remote calls
+GEN_TIMEOUT = int(os.getenv("GEN_TIMEOUT", "60"))  # s

 STOP_SEQS = ["\nUser:", "### System", "### Context", "### Conversation"]

-# Caches (they do not touch CUDA in the main process)
+# Caches (without touching CUDA in the main process)
 _hf_client = None
 _bio_local_cache = {"model": None, "tokenizer": None}

-# ===============================================================
-# PROMPTS / CHAT HELPERS
-# ===============================================================
-def _truncate(text, max_chars=3000):
-    return (text or "")[:max_chars]
+# =========================
+# Prompt helpers
+# =========================
+def _truncate(text, max_chars=3000): return (text or "")[:max_chars]

 def _system_prompt():
     return ("Eres un asistente clínico educativo. No sustituyes el juicio médico. "
             "Usa CONTEXTO_OCR si existe; si falta, pídelo. Evita diagnósticos definitivos.")

-def _ocr_context(ocr_md, ocr_txt):
-    return _truncate(ocr_md) or _truncate(ocr_txt) or ""
+def _ocr_context(ocr_md, ocr_txt): return _truncate(ocr_md) or _truncate(ocr_txt) or ""

 def build_prompt(chat_msgs, ocr_md, ocr_txt, user_msg):
-    """Instruct-style prompt for BioMedLM (no native chat)."""
     sys = _system_prompt()
     ctx = _ocr_context(ocr_md, ocr_txt)

@@ -68,89 +64,85 @@ def build_prompt(chat_msgs, ocr_md, ocr_txt, user_msg):
     prompt += f"### Conversation\n{convo}\nAssistant:"
     return prompt

-# ===============================================================
-# BIO: remote/local adapters (NO CUDA in main)
-# ===============================================================
+# =========================
+# BioMedLM remote/local
+# =========================
 def get_biomedlm():
-    """Decide the mode. Do not load models or touch CUDA here."""
+    """Decide the mode. Do not touch CUDA here."""
     global _hf_client
     if BIO_REMOTE:
         if _hf_client is None:
-            # Force the provider to avoid StopIteration in some huggingface_hub versions
-            _hf_client = InferenceClient(model=BIO_MODEL_ID, token=HF_TOKEN, provider=HF_PROVIDER)
+            # timeout belongs in the client constructor (not in text_generation)
+            _hf_client = InferenceClient(
+                model=BIO_MODEL_ID,
+                provider=HF_PROVIDER,
+                token=HF_TOKEN,
+                timeout=GEN_TIMEOUT,  # ← this is the correct place
+            )
         return ("remote", _hf_client)
     return ("local", None)

-def _hf_text_generation_raw(model_id: str, prompt: str,
-                            temperature: float, top_p: float, rep_penalty: float,
-                            max_new_tokens: int, stop: list, timeout: int) -> str:
-    """
-    Direct HTTP fallback to the Inference API if InferenceClient.text_generation fails.
-    Handles both serverless and TGI-style responses.
-    """
-    url = f"https://api-inference.huggingface.co/models/{model_id}"
+def _hf_http_chat(prompt: str) -> str:
+    """HTTP fallback to the HF router (two possible routes)."""
     headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
     payload = {
-        "inputs": prompt,
-        "parameters": {
-            "max_new_tokens": max_new_tokens,
-            "temperature": temperature,
-            "top_p": top_p,
-            "repetition_penalty": rep_penalty,
-            "stop": stop,
-            "return_full_text": False
-        },
-        "options": {"use_cache": False, "wait_for_model": True}
+        "model": BIO_MODEL_ID,
+        "messages": [{"role": "user", "content": prompt}],
+        "max_tokens": GEN_MAX_NEW_TOKENS,
+        "temperature": GEN_TEMPERATURE,
+        "top_p": GEN_TOP_P,
+        "stop": STOP_SEQS,
     }
-    r = requests.post(url, headers=headers, json=payload, timeout=timeout)
-    if r.status_code == 200:
-        data = r.json()
-        # The response may be a list with {generated_text} or a TGI-like dict
-        if isinstance(data, list) and len(data) > 0 and "generated_text" in data[0]:
-            return data[0]["generated_text"]
-        # Some variants return a dict with 'generated_text' or 'text'
-        if isinstance(data, dict):
-            if "generated_text" in data:
-                return data["generated_text"]
-            if "text" in data:
-                return data["text"]
-        # Fall back to a string
-        return json.dumps(data)[:4000]
-    else:
-        raise RuntimeError(f"HTTP {r.status_code}: {r.text[:1000]}")
+
+    # 1) OpenAI-compatible route
+    urls = [
+        "https://router.huggingface.co/v1/chat/completions",
+        # 2) some clients expect the /hf-inference prefix
+        "https://router.huggingface.co/hf-inference/v1/chat/completions",
+    ]
+    last_exc = None
+    for url in urls:
+        try:
+            r = requests.post(url, headers=headers, json=payload, timeout=GEN_TIMEOUT)
+            if r.status_code == 200:
+                data = r.json()
+                # OpenAI-like response
+                if isinstance(data, dict) and "choices" in data and data["choices"]:
+                    msg = data["choices"][0].get("message") or {}
+                    return (msg.get("content") or "").strip()
+                return json.dumps(data)[:4000]
+            # on a 410 from the old API, keep trying
+            last_exc = RuntimeError(f"HTTP {r.status_code}: {r.text[:800]}")
+        except Exception as e:
+            last_exc = e
+    raise last_exc or RuntimeError("HF router error")

 def call_biomedlm_remote(prompt: str) -> (str, str):
     """
-    Try InferenceClient.text_generation; if it raises StopIteration or anything else,
-    fall back to raw HTTP. Returns (answer, debug_msg)
+    Use chat.completions.create (OpenAI-like). On failure, fall back to the HTTP router.
+    Returns (answer, debug_msg)
     """
     client = get_biomedlm()[1]
     try:
-        out = client.text_generation(
-            prompt,
-            max_new_tokens=GEN_MAX_NEW_TOKENS,
+        resp = client.chat.completions.create(
+            model=BIO_MODEL_ID,
+            messages=[{"role": "user", "content": prompt}],
+            max_tokens=GEN_MAX_NEW_TOKENS,
             temperature=GEN_TEMPERATURE,
             top_p=GEN_TOP_P,
-            repetition_penalty=GEN_REP_PENALTY,
-            stop_sequences=STOP_SEQS,
-            details=False,  # keep a plain string
-            stream=False,
-            timeout=GEN_TIMEOUT,
+            stop=STOP_SEQS,
        )
-        answer = out.strip() if isinstance(out, str) else str(out)
+        answer = (resp.choices[0].message.content or "").strip()
         return answer, ""
     except Exception as e:
-        # Fall back to HTTP
+        # HTTP fallback to the new router
         try:
-            answer = _hf_text_generation_raw(
-                BIO_MODEL_ID, prompt,
-                GEN_TEMPERATURE, GEN_TOP_P, GEN_REP_PENALTY,
-                GEN_MAX_NEW_TOKENS, STOP_SEQS, GEN_TIMEOUT
-            ).strip()
-            dbg = f"[Fallback HTTP HF] {e.__class__.__name__}: {str(e) or repr(e)}"
-            return answer, dbg
+            answer = _hf_http_chat(prompt)
+            return answer, f"[Fallback HTTP router] {e.__class__.__name__}: {e}"
         except Exception as e2:
-            raise RuntimeError(f"Remote generation failed: {e.__class__.__name__}: {e} | HTTP fallback: {e2.__class__.__name__}: {e2}")
+            raise RuntimeError(
+                f"Remote generation failed: {e.__class__.__name__}: {e} | HTTP fallback: {e2.__class__.__name__}: {e2}"
+            )

 @spaces.GPU
 def biomedlm_infer_local(prompt: str,
@@ -158,26 +150,21 @@ def biomedlm_infer_local(prompt: str,
                          top_p=0.9,
                          rep_penalty=1.1,
                          max_new_tokens=512) -> str:
-    """Runs locally in the GPU worker; catches errors and returns them with an ERR:: prefix"""
+    """Runs locally in the GPU worker; returns OK:: or ERR::..."""
     try:
-        # Lazy load inside the GPU worker
         if _bio_local_cache["model"] is None:
             tok = AutoTokenizer.from_pretrained(BIO_MODEL_ID, use_fast=True)
-            if torch.cuda.is_available():
-                dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
-            else:
-                dtype = torch.float32
-
+            dtype = torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) else (
+                torch.float16 if torch.cuda.is_available() else torch.float32
+            )
             model = AutoModelForCausalLM.from_pretrained(BIO_MODEL_ID, torch_dtype=dtype)
             if torch.cuda.is_available():
                 model = model.to("cuda")
-
             _bio_local_cache["model"] = model.eval()
             _bio_local_cache["tokenizer"] = tok

         model = _bio_local_cache["model"]
         tok = _bio_local_cache["tokenizer"]
-
         inputs = tok(prompt, return_tensors="pt")
         if torch.cuda.is_available():
             inputs = {k: v.to("cuda") for k, v in inputs.items()}
@@ -193,21 +180,16 @@
         )
         text = tok.decode(gen_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
         return "OK::" + text.strip()
-
     except Exception as e:
-        err_cls = e.__class__.__name__
-        return f"ERR::[{err_cls}] {str(e) or repr(e)}"
+        return f"ERR::[{e.__class__.__name__}] {str(e) or repr(e)}"

 def biomedlm_reply(user_msg, chat_msgs, ocr_md, ocr_txt):
-    """Wrapper that picks remote/local and handles fallback plus explicit error messages."""
     try:
         if not user_msg:
             user_msg = "Analiza el CONTEXTO_OCR anterior y responde a partir de ese contenido."
         prompt = build_prompt(chat_msgs, ocr_md, ocr_txt, user_msg)

-        mode, _handle = get_biomedlm()
-
-        # Preferred: remote (avoids ZeroGPU limits and CUDA in main)
+        mode, _ = get_biomedlm()
         if mode == "remote":
             answer, dbg = call_biomedlm_remote(prompt)
             updated = (chat_msgs or []) + [
@@ -224,7 +206,6 @@ def biomedlm_reply(user_msg, chat_msgs, ocr_md, ocr_txt):
                 rep_penalty=GEN_REP_PENALTY,
                 max_new_tokens=GEN_MAX_NEW_TOKENS
             )
-
             if res.startswith("OK::"):
                 answer = res[4:]
                 updated = (chat_msgs or []) + [
@@ -233,23 +214,14 @@
                 ]
                 return updated, "", gr.update(value="")
             else:
-                # Local error: the detailed message arrives in res
                 err_msg = res[5:] if res.startswith("ERR::") else res
-
-                # Automatic fallback to remote if allowed
-                if BIO_FALLBACK_REMOTE:
-                    answer2, dbg2 = call_biomedlm_remote(prompt)
-                    updated = (chat_msgs or []) + [
-                        {"role": "user", "content": user_msg},
-                        {"role": "assistant", "content": answer2}
-                    ]
-                    return updated, "", gr.update(value=f"[Local->Remoto fallback]\n{err_msg}\n{dbg2}")
-                else:
-                    updated = (chat_msgs or []) + [
-                        {"role": "user", "content": user_msg},
-                        {"role": "assistant", "content": "⚠️ Error LLM (local). Revisa el panel de debug."}
-                    ]
-                    return updated, "", gr.update(value=err_msg)
+                # fall back to remote automatically
+                answer2, dbg2 = call_biomedlm_remote(prompt)
+                updated = (chat_msgs or []) + [
+                    {"role": "user", "content": user_msg},
+                    {"role": "assistant", "content": answer2}
+                ]
+                return updated, "", gr.update(value=f"[Local->Remoto fallback]\n{err_msg}\n{dbg2}")

     except Exception as e:
         err = f"{e.__class__.__name__}: {str(e) or repr(e)}"
@@ -260,32 +232,24 @@ def biomedlm_reply(user_msg, chat_msgs, ocr_md, ocr_txt):
         ]
         return updated, "", gr.update(value=f"{err}\n{tb}")

-def clear_chat():
-    return [], "", gr.update(value="")
+def clear_chat(): return [], "", gr.update(value="")

-# ===============================================================
-# DeepSeek-OCR (untouched) with a fallback if FlashAttention2 is unavailable
-# * NO CUDA until @spaces.GPU
-# ===============================================================
+# =========================
+# DeepSeek-OCR (no CUDA in main)
+# =========================
 def _load_ocr_model():
     model_name = "deepseek-ai/DeepSeek-OCR"
     ocr_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
     attn_impl = os.getenv("OCR_ATTN_IMPL", "flash_attention_2")
     try:
         ocr_model = AutoModel.from_pretrained(
-            model_name,
-            _attn_implementation=attn_impl,
-            trust_remote_code=True,
-            use_safetensors=True
+            model_name, _attn_implementation=attn_impl, trust_remote_code=True, use_safetensors=True
         ).eval()
         return ocr_tokenizer, ocr_model
     except Exception as e:
         if any(k in str(e).lower() for k in ["flash_attn", "flashattention2", "flash_attention_2"]):
             ocr_model = AutoModel.from_pretrained(
-                model_name,
-                _attn_implementation="eager",
-                trust_remote_code=True,
-                use_safetensors=True
+                model_name, _attn_implementation="eager", trust_remote_code=True, use_safetensors=True
             ).eval()
             return ocr_tokenizer, ocr_model
         raise
@@ -297,7 +261,6 @@ def process_image(image, model_size, task_type, is_eval_mode):
     if image is None:
         return None, "Please upload an image first.", "Please upload an image first."

-    # Move to GPU ONLY inside the worker
     if torch.cuda.is_available():
         dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
         model_device = model.to(dtype).to("cuda")
@@ -347,9 +310,9 @@
     text_result = plain_text_result if plain_text_result else markdown_content
     return result_image, markdown_content, text_result

-# ===============================================================
+# =========================
 # UI (Gradio 5)
-# ===============================================================
+# =========================
 with gr.Blocks(title="DeepSeek-OCR + BioMedLM", theme=gr.themes.Soft()) as demo:
     gr.Markdown(
         """
@@ -366,29 +329,18 @@ with gr.Blocks(title="DeepSeek-OCR + BioMedLM", theme=gr.themes.Soft()) as demo:
     with gr.Row():
         with gr.Column(scale=1):
             image_input = gr.Image(type="pil", label="Upload Image", sources=["upload", "clipboard", "webcam"])
-            model_size = gr.Dropdown(
-                choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
-                value="Gundam (Recommended)",
-                label="Model Size"
-            )
-            task_type = gr.Dropdown(
-                choices=["Free OCR", "Convert to Markdown"],
-                value="Convert to Markdown",
-                label="Task Type"
-            )
-            eval_mode_checkbox = gr.Checkbox(
-                value=False,
-                label="Enable Evaluation Mode",
-                info="Solo texto (más rápido). Desmárcalo para ver imagen anotada y markdown."
-            )
+            model_size = gr.Dropdown(choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
+                                     value="Gundam (Recommended)", label="Model Size")
+            task_type = gr.Dropdown(choices=["Free OCR", "Convert to Markdown"],
+                                    value="Convert to Markdown", label="Task Type")
+            eval_mode_checkbox = gr.Checkbox(value=False, label="Enable Evaluation Mode",
+                                             info="Solo texto (más rápido). Desmárcalo para ver imagen anotada y markdown.")
             submit_btn = gr.Button("Process Image", variant="primary")

         with gr.Column(scale=2):
             with gr.Tabs():
-                with gr.TabItem("Annotated Image"):
-                    output_image = gr.Image(interactive=False)
-                with gr.TabItem("Markdown Preview"):
-                    output_markdown = gr.Markdown()
+                with gr.TabItem("Annotated Image"): output_image = gr.Image(interactive=False)
+                with gr.TabItem("Markdown Preview"): output_markdown = gr.Markdown()
                 with gr.TabItem("Markdown Source (or Eval Output)"):
                     output_text = gr.Textbox(lines=18, show_copy_button=True, interactive=False)
     with gr.Row():
@@ -416,11 +368,8 @@ with gr.Blocks(title="DeepSeek-OCR + BioMedLM", theme=gr.themes.Soft()) as demo:
         outputs=[ocr_md_state, ocr_txt_state, md_preview, txt_preview],
     )

-    send_btn.click(
-        fn=biomedlm_reply,
-        inputs=[user_in, chatbot, ocr_md_state, ocr_txt_state],
-        outputs=[chatbot, user_in, error_box]
-    )
+    send_btn.click(fn=biomedlm_reply, inputs=[user_in, chatbot, ocr_md_state, ocr_txt_state],
+                   outputs=[chatbot, user_in, error_box])
     clear_btn.click(fn=clear_chat, outputs=[chatbot, user_in, error_box])

 if __name__ == "__main__":
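
For reference, a minimal sketch of the call pattern the new code adopts: an OpenAI-style chat completion through InferenceClient, with a raw HTTP request to the router as the fallback. This assumes a recent huggingface_hub (provider-aware client) and an HF_TOKEN in the environment; the prompt string is a hypothetical placeholder, and whether a given model is actually served on the chat route depends on the provider.

# Sketch only: mirrors call_biomedlm_remote() / _hf_http_chat() from the diff above.
import os
import requests
from huggingface_hub import InferenceClient

MODEL = os.getenv("BIO_MODEL_ID", "stanford-crfm/BioMedLM")
TOKEN = os.getenv("HF_TOKEN")

client = InferenceClient(model=MODEL, provider="hf-inference", token=TOKEN, timeout=60)
messages = [{"role": "user", "content": "Summarize: three days of fever and cough."}]  # hypothetical prompt

try:
    # Primary path: OpenAI-compatible chat API exposed by huggingface_hub
    resp = client.chat.completions.create(model=MODEL, messages=messages, max_tokens=128)
    print(resp.choices[0].message.content)
except Exception:
    # Fallback path: the router speaks the OpenAI chat-completions protocol over plain HTTP
    r = requests.post(
        "https://router.huggingface.co/v1/chat/completions",
        headers={"Authorization": f"Bearer {TOKEN}"},
        json={"model": MODEL, "messages": messages, "max_tokens": 128},
        timeout=60,
    )
    r.raise_for_status()
    print(r.json()["choices"][0]["message"]["content"])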
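
Since the fallback route is OpenAI-compatible, the requests-based _hf_http_chat() could equally be written against the official OpenAI SDK pointed at the router. A sketch assuming the openai package is installed; this is an alternative design, not what the commit ships:

import os
from openai import OpenAI  # assumption: `pip install openai`

# The HF router accepts an HF token as the API key on its OpenAI-compatible base URL.
client = OpenAI(base_url="https://router.huggingface.co/v1", api_key=os.environ["HF_TOKEN"])
resp = client.chat.completions.create(
    model=os.getenv("BIO_MODEL_ID", "stanford-crfm/BioMedLM"),
    messages=[{"role": "user", "content": "One-sentence test."}],  # hypothetical prompt
    max_tokens=64,
)
print(resp.choices[0].message.content)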
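
Both the local BioMedLM path and the OCR path rely on the ZeroGPU contract that CUDA is touched only inside a @spaces.GPU-decorated function, never in the main process. A minimal sketch of that pattern (gpu_double is a hypothetical name):

import spaces
import torch

@spaces.GPU  # ZeroGPU attaches a GPU only for the duration of this call
def gpu_double(x: float) -> float:
    # CUDA work happens only inside the decorated worker
    if torch.cuda.is_available():
        return (torch.tensor([x], device="cuda") * 2).item()
    return x * 2  # CPU fallback keeps the sketch runnable anywhere

print(gpu_double(21.0))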