jorgeiv500 committed on
Commit c93afa6 · verified · 1 Parent(s): 42632ea

Update app.py

Files changed (1)
  1. app.py +135 -117
app.py CHANGED
@@ -1,80 +1,140 @@
- # app.py — DeepSeek-OCR + DeepSeek-R1 Medical Mini (fast local GGUF) — Gradio 5
  import os, tempfile, traceback
  import gradio as gr
  import torch
  from PIL import Image
- from transformers import AutoModel, AutoTokenizer
  import spaces
- from huggingface_hub import hf_hub_download
- from llama_cpp import Llama

  # ===============================================================
- # CHAT: DeepSeek-R1 Medical Mini, LOCAL ONLY (GGUF), for maximum speed without API tokens
- # - A specific file can be forced via GGUF_REPO / GGUF_FILE
- # - If none is set, we try Q4 (fast) and fall back to f16 if it is unavailable
  # ===============================================================
- GGUF_REPO = os.getenv("GGUF_REPO", "mradermacher/DeepSeek-r1-Medical-Mini-GGUF").strip()
- GGUF_FILE = os.getenv("GGUF_FILE", "").strip()
-
- # Preference order (fastest -> heaviest). Change the names if your repo uses different ones.
- _DEFAULT_CANDIDATES = [
-     "DeepSeek-r1-Medical-Mini.Q4_K_M.gguf",
-     "DeepSeek-r1-Medical-Mini.Q4_0.gguf",
-     "DeepSeek-r1-Medical-Mini.Q5_0.gguf",
-     "DeepSeek-r1-Medical-Mini.Q8_0.gguf",
-     "DeepSeek-r1-Medical-Mini.f16.gguf",
- ]
- GGUF_CANDIDATES = [GGUF_FILE] if GGUF_FILE else _DEFAULT_CANDIDATES
-
- N_CTX = int(os.getenv("N_CTX", "2048"))
- N_THREADS = int(os.getenv("N_THREADS", str(os.cpu_count() or 4)))
- N_GPU_LAYERS = int(os.getenv("N_GPU_LAYERS", "0"))  # Zero/CPU => 0
- N_BATCH = int(os.getenv("N_BATCH", "96"))
-
- _llm = None
- def _download_gguf():
-     last_err = None
-     for fname in GGUF_CANDIDATES:
-         try:
-             path = hf_hub_download(repo_id=GGUF_REPO, filename=fname)
-             return path, fname
-         except Exception as e:
-             last_err = e
-     raise RuntimeError(f"No se pudo descargar GGUF desde {GGUF_REPO}. Último error: {last_err}")
-
- def get_llm():
-     global _llm
-     if _llm is not None:
-         return _llm
-     gguf_path, used = _download_gguf()
-     print(f"[R1/llama.cpp] usando: {used}")
-     _llm = Llama(
-         model_path=gguf_path,
-         n_ctx=N_CTX,
-         n_threads=N_THREADS,
-         n_gpu_layers=N_GPU_LAYERS,
-         n_batch=N_BATCH,
-         verbose=False,
-     )
-     return _llm
-
- def _format_chatml(messages):
-     parts = []
-     for m in messages:
-         parts.append(f"<|im_start|>{m.get('role','user')}\n{m.get('content','')}<|im_end|>\n")
-     parts.append("<|im_start|>assistant\n")
-     return "".join(parts)
-
- def r1_chat_local(messages, temperature=0.2, max_tokens=384):
-     # llama.cpp accepts messages directly; if your build does not, use prompt=_format_chatml(messages)
-     llm = get_llm()
-     out = llm.create_chat_completion(messages=messages, temperature=temperature, max_tokens=max_tokens)
-     return out["choices"][0]["message"]["content"]
-
- # Optional warmup
- if os.getenv("WARMUP", "0") == "1":
-     try: get_llm()
-     except Exception: pass

  # ===============================================================
  # DeepSeek-OCR (untouched), with a fallback when FlashAttention2 is unavailable
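For reference, the backend removed here can be exercised on its own with the same llama.cpp calls it used. A minimal sketch, assuming llama-cpp-python and huggingface_hub are installed and that the Q4_K_M filename (the first candidate in the list above) actually exists in the GGUF repo:

from huggingface_hub import hf_hub_download
from llama_cpp import Llama

# Download one concrete quant instead of walking the candidate list.
gguf_path = hf_hub_download(
    repo_id="mradermacher/DeepSeek-r1-Medical-Mini-GGUF",
    filename="DeepSeek-r1-Medical-Mini.Q4_K_M.gguf",  # assumed filename; verify against the repo
)

# Roughly the constructor arguments the removed get_llm() passes (CPU defaults).
llm = Llama(model_path=gguf_path, n_ctx=2048, n_threads=4, n_gpu_layers=0, n_batch=96, verbose=False)
out = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Resume el CONTEXTO_OCR en tres puntos."}],
    temperature=0.2,
    max_tokens=256,
)
print(out["choices"][0]["message"]["content"])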
@@ -152,57 +212,15 @@ def process_image(image, model_size, task_type, is_eval_mode):
      text_result = plain_text_result if plain_text_result else markdown_content
      return result_image, markdown_content, text_result

- # ===============================================================
- # Chat (injects the OCR context) with local R1
- # ===============================================================
- def _truncate(text, max_chars=3000): return (text or "")[:max_chars]
-
- def _system_prompt():
-     return ("Eres un asistente clínico educativo. No sustituyes el juicio médico. "
-             "Usa CONTEXTO_OCR si existe; si falta, pídelo. Evita diagnósticos definitivos.")
-
- def _ocr_context(ocr_md, ocr_txt): return _truncate(ocr_md) or _truncate(ocr_txt) or ""
-
- def to_chat_messages(chat_msgs, ocr_md, ocr_txt):
-     sys = _system_prompt()
-     ctx = _ocr_context(ocr_md, ocr_txt)
-     if ctx:
-         sys += ("\n\n---\n"
-                 "CONTEXTO_OCR (fuente principal; si falta un dato, dilo explícitamente):\n"
-                 f"{ctx}\n---")
-     msgs = [{"role": "system", "content": sys}]
-     for m in (chat_msgs or []):
-         if m.get("role") in ("user", "assistant"):
-             msgs.append({"role": m["role"], "content": m.get("content", "")})
-     return msgs
-
- def r1_reply(user_msg, chat_msgs, ocr_md, ocr_txt):
-     if not user_msg:
-         user_msg = "Analiza el CONTEXTO_OCR anterior y responde a partir de ese contenido."
-     try:
-         msgs = to_chat_messages(chat_msgs, ocr_md, ocr_txt) + [{"role": "user", "content": user_msg}]
-         answer = r1_chat_local(msgs, temperature=0.2, max_tokens=512)
-         updated = (chat_msgs or []) + [{"role": "user", "content": user_msg},
-                                        {"role": "assistant", "content": answer}]
-         return updated, "", gr.update(value="")
-     except Exception as e:
-         err = f"{e.__class__.__name__}: {str(e) or repr(e)}"
-         tb = traceback.format_exc(limit=2)
-         updated = (chat_msgs or []) + [{"role": "user", "content": user_msg or ""},
-                                        {"role": "assistant", "content": f"⚠️ Error LLM: {err}"}]
-         return updated, "", gr.update(value=f"{err}\n{tb}")
-
- def clear_chat(): return [], "", gr.update(value="")
-
  # ===============================================================
  # UI (Gradio 5)
  # ===============================================================
- with gr.Blocks(title="DeepSeek-OCR + R1 Medical (GGUF rápido)", theme=gr.themes.Soft()) as demo:
      gr.Markdown(
          """
-         # DeepSeek-OCR → Chat Médico con **DeepSeek-R1 Medical Mini (GGUF local rápido)**
          1) **Sube una imagen** y corre **OCR** (imagen anotada, Markdown y texto).
-         2) **Chatea** con **R1 Medical Mini** usando automáticamente el **OCR** como contexto.
          *Uso educativo; no reemplaza consejo médico.*
          """
      )
@@ -231,10 +249,10 @@ with gr.Blocks(title="DeepSeek-OCR + R1 Medical (GGUF rápido)", theme=gr.themes
      md_preview = gr.Textbox(label="Snapshot Markdown OCR", lines=10, interactive=False)
      txt_preview = gr.Textbox(label="Snapshot Texto OCR", lines=10, interactive=False)

-     gr.Markdown("## Chat Clínico (R1 Medical Mini — GGUF local)")
      with gr.Row():
          with gr.Column(scale=2):
-             chatbot = gr.Chatbot(label="Asistente OCR (R1 GGUF)", type="messages", height=420)
              user_in = gr.Textbox(label="Mensaje", placeholder="Escribe tu consulta… (vacío = analiza solo el OCR)", lines=2)
              with gr.Row():
                  send_btn = gr.Button("Enviar", variant="primary")
@@ -252,7 +270,7 @@ with gr.Blocks(title="DeepSeek-OCR + R1 Medical (GGUF rápido)", theme=gr.themes
          outputs=[ocr_md_state, ocr_txt_state, md_preview, txt_preview],
      )

-     send_btn.click(fn=r1_reply, inputs=[user_in, chatbot, ocr_md_state, ocr_txt_state],
                     outputs=[chatbot, user_in, error_box])
      clear_btn.click(fn=clear_chat, outputs=[chatbot, user_in, error_box])
 
 
+ # app.py — DeepSeek-OCR + BioMedLM (remote or local) — Gradio 5
  import os, tempfile, traceback
  import gradio as gr
  import torch
  from PIL import Image
+ from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
  import spaces
+ from huggingface_hub import hf_hub_download, InferenceClient

  # ===============================================================
+ # CHAT: BioMedLM remote (HF Inference) or local (Transformers)
+ # - Remote mode: BIO_REMOTE=1 (recommended on Zero/CPU Spaces)
+ # - Local mode: BIO_REMOTE=0 (uses PyTorch; 2.7B params, can be slow on CPU)
+ # - Variables: BIO_MODEL_ID=stanford-crfm/BioMedLM, HF_TOKEN
  # ===============================================================
+ BIO_REMOTE = os.getenv("BIO_REMOTE", "0") == "1"
+ BIO_MODEL_ID = os.getenv("BIO_MODEL_ID", "stanford-crfm/BioMedLM").strip()
+ HF_TOKEN = os.getenv("HF_TOKEN")
+
+ # Default generation parameters
+ GEN_TEMPERATURE = float(os.getenv("GEN_TEMPERATURE", "0.2"))
+ GEN_TOP_P = float(os.getenv("GEN_TOP_P", "0.9"))
+ GEN_MAX_NEW_TOKENS = int(os.getenv("GEN_MAX_NEW_TOKENS", "512"))
+ GEN_REP_PENALTY = float(os.getenv("GEN_REP_PENALTY", "1.1"))
+
+ _bio_model = None
+ _bio_tokenizer = None
+ _hf_client = None
+
+ def get_biomedlm():
+     """Return the BioMedLM handle for the configured remote/local mode."""
+     global _bio_model, _bio_tokenizer, _hf_client
+     if BIO_REMOTE:
+         if _hf_client is None:
+             _hf_client = InferenceClient(model=BIO_MODEL_ID, token=HF_TOKEN)
+         return ("remote", _hf_client)
+     else:
+         if _bio_model is None:
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+             dtype = torch.bfloat16 if (device == "cuda" and torch.cuda.is_bf16_supported()) else (
+                 torch.float16 if device == "cuda" else torch.float32
+             )
+             _bio_tokenizer = AutoTokenizer.from_pretrained(BIO_MODEL_ID, use_fast=True)
+             _bio_model = AutoModelForCausalLM.from_pretrained(
+                 BIO_MODEL_ID,
+                 torch_dtype=dtype,
+             )
+             _bio_model = _bio_model.to(device)
+         return ("local", (_bio_model, _bio_tokenizer))
+
+ def _system_prompt():
+     return ("Eres un asistente clínico educativo. No sustituyes el juicio médico. "
+             "Usa CONTEXTO_OCR si existe; si falta, pídelo. Evita diagnósticos definitivos.")
+
+ def _truncate(text, max_chars=3000):
+     return (text or "")[:max_chars]
+
+ def _ocr_context(ocr_md, ocr_txt):
+     return _truncate(ocr_md) or _truncate(ocr_txt) or ""
+
+ def build_prompt(chat_msgs, ocr_md, ocr_txt, user_msg):
+     """Build an instruct-style prompt suited to BioMedLM (it is not a chat model)."""
+     sys = _system_prompt()
+     ctx = _ocr_context(ocr_md, ocr_txt)
+
+     history_lines = []
+     for m in (chat_msgs or []):
+         role = m.get("role")
+         content = (m.get("content") or "").strip()
+         if not content:
+             continue
+         if role == "user":
+             history_lines.append(f"User: {content}")
+         elif role == "assistant":
+             history_lines.append(f"Assistant: {content}")
+
+     if user_msg:
+         history_lines.append(f"User: {user_msg}")
+
+     convo = "\n".join(history_lines).strip()
+     prompt = f"### System\n{sys}\n\n"
+     if ctx:
+         prompt += f"### Context (OCR)\n{ctx}\n\n"
+     prompt += f"### Conversation\n{convo}\nAssistant:"
+     return prompt
+
+ def biomedlm_reply(user_msg, chat_msgs, ocr_md, ocr_txt):
+     """Generate a reply with BioMedLM (remote or local)."""
+     try:
+         if not user_msg:
+             user_msg = "Analiza el CONTEXTO_OCR anterior y responde a partir de ese contenido."
+         prompt = build_prompt(chat_msgs, ocr_md, ocr_txt, user_msg)
+         mode, handle = get_biomedlm()
+
+         if mode == "remote":
+             # HF Inference (text-generation)
+             out = handle.text_generation(
+                 prompt,
+                 max_new_tokens=GEN_MAX_NEW_TOKENS,
+                 temperature=GEN_TEMPERATURE,
+                 top_p=GEN_TOP_P,
+                 repetition_penalty=GEN_REP_PENALTY,
+                 # Soft stops; keeps the model from spilling into the next section
+                 stop_sequences=["\nUser:", "### System", "### Context", "### Conversation"]
+             )
+             answer = out
+         else:
+             # Local (PyTorch)
+             model, tok = handle
+             inputs = tok(prompt, return_tensors="pt").to(model.device)
+             gen_ids = model.generate(
+                 **inputs,
+                 do_sample=True,
+                 temperature=GEN_TEMPERATURE,
+                 top_p=GEN_TOP_P,
+                 repetition_penalty=GEN_REP_PENALTY,
+                 max_new_tokens=GEN_MAX_NEW_TOKENS,
+                 eos_token_id=tok.eos_token_id,
+             )
+             answer = tok.decode(gen_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
+
+         updated = (chat_msgs or []) + [
+             {"role": "user", "content": user_msg},
+             {"role": "assistant", "content": answer.strip()}
+         ]
+         return updated, "", gr.update(value="")
+     except Exception as e:
+         err = f"{e.__class__.__name__}: {str(e) or repr(e)}"
+         tb = traceback.format_exc(limit=2)
+         updated = (chat_msgs or []) + [
+             {"role": "user", "content": user_msg or ""},
+             {"role": "assistant", "content": f"⚠️ Error LLM: {err}"}
+         ]
+         return updated, "", gr.update(value=f"{err}\n{tb}")
+
+ def clear_chat():
+     return [], "", gr.update(value="")
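The local branch above is plain causal-LM generation. A minimal sketch of it in isolation, assuming enough memory for stanford-crfm/BioMedLM (2.7B parameters) and using a made-up OCR snippet in the same "### System / ### Context (OCR) / ### Conversation" layout that build_prompt() emits:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "stanford-crfm/BioMedLM"
device = "cuda" if torch.cuda.is_available() else "cpu"
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)

# Hypothetical prompt mirroring build_prompt(); the OCR value and question are examples only.
prompt = (
    "### System\nEres un asistente clínico educativo.\n\n"
    "### Context (OCR)\nHemoglobina: 9.1 g/dL\n\n"
    "### Conversation\nUser: ¿Qué sugiere este valor?\nAssistant:"
)
inputs = tok(prompt, return_tensors="pt").to(device)
gen_ids = model.generate(
    **inputs,
    do_sample=True,
    temperature=0.2,
    top_p=0.9,
    max_new_tokens=128,
    eos_token_id=tok.eos_token_id,
)
print(tok.decode(gen_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))

The remote branch sends the same prompt through InferenceClient.text_generation with the stop sequences shown above; whether the serverless Inference API currently serves this model is worth verifying before setting BIO_REMOTE=1.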
 
  # ===============================================================
  # DeepSeek-OCR (untouched), with a fallback when FlashAttention2 is unavailable

      text_result = plain_text_result if plain_text_result else markdown_content
      return result_image, markdown_content, text_result
  # ===============================================================
  # UI (Gradio 5)
  # ===============================================================
+ with gr.Blocks(title="DeepSeek-OCR + BioMedLM", theme=gr.themes.Soft()) as demo:
      gr.Markdown(
          """
+         # DeepSeek-OCR → Chat Médico con **BioMedLM**
          1) **Sube una imagen** y corre **OCR** (imagen anotada, Markdown y texto).
+         2) **Chatea** con **BioMedLM** usando automáticamente el **OCR** como contexto.
          *Uso educativo; no reemplaza consejo médico.*
          """
      )
 
      md_preview = gr.Textbox(label="Snapshot Markdown OCR", lines=10, interactive=False)
      txt_preview = gr.Textbox(label="Snapshot Texto OCR", lines=10, interactive=False)

+     gr.Markdown("## Chat Clínico (BioMedLM)")
      with gr.Row():
          with gr.Column(scale=2):
+             chatbot = gr.Chatbot(label="Asistente OCR (BioMedLM)", type="messages", height=420)
              user_in = gr.Textbox(label="Mensaje", placeholder="Escribe tu consulta… (vacío = analiza solo el OCR)", lines=2)
              with gr.Row():
                  send_btn = gr.Button("Enviar", variant="primary")
 
          outputs=[ocr_md_state, ocr_txt_state, md_preview, txt_preview],
      )

+     send_btn.click(fn=biomedlm_reply, inputs=[user_in, chatbot, ocr_md_state, ocr_txt_state],
                     outputs=[chatbot, user_in, error_box])
      clear_btn.click(fn=clear_chat, outputs=[chatbot, user_in, error_box])
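Because the Chatbot is created with type="messages", the history that send_btn.click feeds into biomedlm_reply (and gets back) is a plain list of role/content dicts; the sketch below uses made-up content. Mode selection happens once at import time, so BIO_REMOTE, BIO_MODEL_ID and HF_TOKEN must be set in the Space settings (or the shell) before app.py starts.

# Shape of the chat state flowing through send_btn.click / clear_btn.click.
# The content strings are hypothetical; the structure is what gr.Chatbot(type="messages") expects.
chat_msgs = [
    {"role": "user", "content": "¿Qué muestra el informe?"},
    {"role": "assistant", "content": "El OCR sugiere anemia; confírmelo con su médico."},
]

# biomedlm_reply returns one value per output component of the click handler:
#   (updated_history, "", gr.update(value=""))  ->  chatbot, user_in, error_box
# clear_chat() resets the same three components with ([], "", gr.update(value="")).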