jorgeiv500 commited on
Commit
6bee325
·
verified ·
1 Parent(s): 930b9fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -29
app.py CHANGED
@@ -1,11 +1,12 @@
1
- # app.py — DeepSeek-OCR + BioMedLM con ZeroGPU-safe y fallback remoto — Gradio 5
2
- import os, tempfile, traceback
3
  import gradio as gr
4
  import torch
5
  from PIL import Image
6
  from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
7
  import spaces
8
  from huggingface_hub import InferenceClient
 
9
 
10
  # ===============================================================
11
  # CONFIG (env)
@@ -13,12 +14,16 @@ from huggingface_hub import InferenceClient
13
  BIO_REMOTE = os.getenv("BIO_REMOTE", "1") == "1" # Recomendado en Spaces ZeroGPU
14
  BIO_MODEL_ID = os.getenv("BIO_MODEL_ID", "stanford-crfm/BioMedLM").strip()
15
  HF_TOKEN = os.getenv("HF_TOKEN")
16
- BIO_FALLBACK_REMOTE = os.getenv("BIO_FALLBACK_REMOTE", "1") == "1" # Si falla local => intenta remoto
 
17
 
18
  GEN_TEMPERATURE = float(os.getenv("GEN_TEMPERATURE", "0.2"))
19
  GEN_TOP_P = float(os.getenv("GEN_TOP_P", "0.9"))
20
  GEN_MAX_NEW_TOKENS = int(os.getenv("GEN_MAX_NEW_TOKENS", "512"))
21
  GEN_REP_PENALTY = float(os.getenv("GEN_REP_PENALTY", "1.1"))
 
 
 
22
 
23
  # Caches (no tocan CUDA en el proceso principal)
24
  _hf_client = None
@@ -71,10 +76,82 @@ def get_biomedlm():
71
  global _hf_client
72
  if BIO_REMOTE:
73
  if _hf_client is None:
74
- _hf_client = InferenceClient(model=BIO_MODEL_ID, token=HF_TOKEN)
 
75
  return ("remote", _hf_client)
76
  return ("local", None)
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  @spaces.GPU
79
  def biomedlm_infer_local(prompt: str,
80
  temperature=0.2,
@@ -118,7 +195,6 @@ def biomedlm_infer_local(prompt: str,
118
  return "OK::" + text.strip()
119
 
120
  except Exception as e:
121
- # Devolver mensaje de error rico (no levantar excepción para que ZeroGPU no lo opaque)
122
  err_cls = e.__class__.__name__
123
  return f"ERR::[{err_cls}] {str(e) or repr(e)}"
124
 
@@ -129,24 +205,16 @@ def biomedlm_reply(user_msg, chat_msgs, ocr_md, ocr_txt):
129
  user_msg = "Analiza el CONTEXTO_OCR anterior y responde a partir de ese contenido."
130
  prompt = build_prompt(chat_msgs, ocr_md, ocr_txt, user_msg)
131
 
132
- mode, handle = get_biomedlm()
133
 
134
- # Preferido: remoto (evita límites ZeroGPU)
135
  if mode == "remote":
136
- out = handle.text_generation(
137
- prompt,
138
- max_new_tokens=GEN_MAX_NEW_TOKENS,
139
- temperature=GEN_TEMPERATURE,
140
- top_p=GEN_TOP_P,
141
- repetition_penalty=GEN_REP_PENALTY,
142
- stop_sequences=["\nUser:", "### System", "### Context", "### Conversation"]
143
- )
144
- answer = out.strip() if isinstance(out, str) else str(out)
145
  updated = (chat_msgs or []) + [
146
  {"role": "user", "content": user_msg},
147
  {"role": "assistant", "content": answer}
148
  ]
149
- return updated, "", gr.update(value="")
150
 
151
  # Local (ZeroGPU)
152
  res = biomedlm_infer_local(
@@ -170,22 +238,12 @@ def biomedlm_reply(user_msg, chat_msgs, ocr_md, ocr_txt):
170
 
171
  # Fallback automático a remoto si está permitido
172
  if BIO_FALLBACK_REMOTE:
173
- mode2, handle2 = ("remote", InferenceClient(model=BIO_MODEL_ID, token=HF_TOKEN))
174
- out2 = handle2.text_generation(
175
- prompt,
176
- max_new_tokens=GEN_MAX_NEW_TOKENS,
177
- temperature=GEN_TEMPERATURE,
178
- top_p=GEN_TOP_P,
179
- repetition_penalty=GEN_REP_PENALTY,
180
- stop_sequences=["\nUser:", "### System", "### Context", "### Conversation"]
181
- )
182
- answer2 = out2.strip() if isinstance(out2, str) else str(out2)
183
  updated = (chat_msgs or []) + [
184
  {"role": "user", "content": user_msg},
185
  {"role": "assistant", "content": answer2}
186
  ]
187
- # Enviar detalle al panel de debug
188
- return updated, "", gr.update(value=f"[Local->Remoto fallback]\n{err_msg}")
189
  else:
190
  updated = (chat_msgs or []) + [
191
  {"role": "user", "content": user_msg},
 
1
# app.py — DeepSeek-OCR + BioMedLM, hardened against StopIteration (HF) and ZeroGPU — Gradio 5
import os, tempfile, traceback, json

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
import spaces
from huggingface_hub import InferenceClient
import requests  # direct HTTP fallback to the HF Inference API when InferenceClient fails

# ===============================================================
# CONFIG (env)
# ===============================================================
BIO_REMOTE = os.getenv("BIO_REMOTE", "1") == "1"  # recommended on ZeroGPU Spaces
BIO_MODEL_ID = os.getenv("BIO_MODEL_ID", "stanford-crfm/BioMedLM").strip()
HF_TOKEN = os.getenv("HF_TOKEN")
HF_PROVIDER = os.getenv("HF_PROVIDER", "hf-inference").strip()  # pin the provider to avoid StopIteration
BIO_FALLBACK_REMOTE = os.getenv("BIO_FALLBACK_REMOTE", "1") == "1"  # if local fails => try remote

# Generation parameters (all overridable via environment variables)
GEN_TEMPERATURE = float(os.getenv("GEN_TEMPERATURE", "0.2"))
GEN_TOP_P = float(os.getenv("GEN_TOP_P", "0.9"))
GEN_MAX_NEW_TOKENS = int(os.getenv("GEN_MAX_NEW_TOKENS", "512"))
GEN_REP_PENALTY = float(os.getenv("GEN_REP_PENALTY", "1.1"))
GEN_TIMEOUT = int(os.getenv("GEN_TIMEOUT", "60"))  # seconds for remote calls

# Sequences that terminate generation (shared by local and remote paths)
STOP_SEQS = ["\nUser:", "### System", "### Context", "### Conversation"]

# Caches (must not touch CUDA in the main process)
_hf_client = None
76
  global _hf_client
77
  if BIO_REMOTE:
78
  if _hf_client is None:
79
+ # Fuerza provider para evitar StopIteration en algunas versiones de huggingface_hub
80
+ _hf_client = InferenceClient(model=BIO_MODEL_ID, token=HF_TOKEN, provider=HF_PROVIDER)
81
  return ("remote", _hf_client)
82
  return ("local", None)
83
 
84
def _hf_text_generation_raw(model_id: str, prompt: str,
                            temperature: float, top_p: float, rep_penalty: float,
                            max_new_tokens: int, stop: list, timeout: int) -> str:
    """
    Direct HTTP fallback to the HF Inference API, used when
    InferenceClient.text_generation fails.

    Handles both serverless and TGI-style response shapes.
    Raises RuntimeError on any non-200 HTTP status.
    """
    endpoint = f"https://api-inference.huggingface.co/models/{model_id}"
    auth = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
    body = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": rep_penalty,
            "stop": stop,
            "return_full_text": False
        },
        "options": {"use_cache": False, "wait_for_model": True}
    }
    resp = requests.post(endpoint, headers=auth, json=body, timeout=timeout)
    if resp.status_code != 200:
        # Surface the status plus a truncated body so the caller can log it.
        raise RuntimeError(f"HTTP {resp.status_code}: {resp.text[:1000]}")

    data = resp.json()
    # Serverless usually returns a list like [{"generated_text": ...}]
    if isinstance(data, list) and len(data) > 0 and "generated_text" in data[0]:
        return data[0]["generated_text"]
    # Some variants return a dict carrying 'generated_text' or 'text'
    if isinstance(data, dict):
        for key in ("generated_text", "text"):
            if key in data:
                return data[key]
    # Last resort: a truncated JSON dump of whatever came back
    return json.dumps(data)[:4000]
121
+
122
def call_biomedlm_remote(prompt: str) -> tuple[str, str]:
    """
    Generate a completion remotely for *prompt*.

    Tries InferenceClient.text_generation first; if it raises
    (StopIteration in some huggingface_hub versions, provider issues,
    unsupported kwargs, ...) it falls back to the raw HTTP Inference API.

    Returns:
        (answer, debug_msg) — debug_msg is "" on the primary path, or a
        short note describing why the HTTP fallback was used.

    Raises:
        RuntimeError: when both the client call and the HTTP fallback fail.
    """
    client = get_biomedlm()[1]
    try:
        # NOTE(review): per-call `timeout=` is not accepted by every
        # huggingface_hub version; if it raises TypeError we still
        # recover through the HTTP fallback below — confirm pinned version.
        out = client.text_generation(
            prompt,
            max_new_tokens=GEN_MAX_NEW_TOKENS,
            temperature=GEN_TEMPERATURE,
            top_p=GEN_TOP_P,
            repetition_penalty=GEN_REP_PENALTY,
            stop_sequences=STOP_SEQS,
            details=False,  # keep the result a plain string
            stream=False,
            timeout=GEN_TIMEOUT,
        )
        answer = out.strip() if isinstance(out, str) else str(out)
        return answer, ""
    except Exception as e:
        # Primary path failed — retry through the raw HTTP endpoint.
        try:
            answer = _hf_text_generation_raw(
                BIO_MODEL_ID, prompt,
                GEN_TEMPERATURE, GEN_TOP_P, GEN_REP_PENALTY,
                GEN_MAX_NEW_TOKENS, STOP_SEQS, GEN_TIMEOUT
            ).strip()
            dbg = f"[Fallback HTTP HF] {e.__class__.__name__}: {str(e) or repr(e)}"
            return answer, dbg
        except Exception as e2:
            # Chain the cause so both failures appear in tracebacks.
            raise RuntimeError(f"Remote generation failed: {e.__class__.__name__}: {e} | HTTP fallback: {e2.__class__.__name__}: {e2}") from e2
154
+
155
  @spaces.GPU
156
  def biomedlm_infer_local(prompt: str,
157
  temperature=0.2,
 
195
  return "OK::" + text.strip()
196
 
197
  except Exception as e:
 
198
  err_cls = e.__class__.__name__
199
  return f"ERR::[{err_cls}] {str(e) or repr(e)}"
200
 
 
205
  user_msg = "Analiza el CONTEXTO_OCR anterior y responde a partir de ese contenido."
206
  prompt = build_prompt(chat_msgs, ocr_md, ocr_txt, user_msg)
207
 
208
+ mode, _handle = get_biomedlm()
209
 
210
+ # Preferido: remoto (evita límites ZeroGPU y CUDA en main)
211
  if mode == "remote":
212
+ answer, dbg = call_biomedlm_remote(prompt)
 
 
 
 
 
 
 
 
213
  updated = (chat_msgs or []) + [
214
  {"role": "user", "content": user_msg},
215
  {"role": "assistant", "content": answer}
216
  ]
217
+ return updated, "", gr.update(value=dbg)
218
 
219
  # Local (ZeroGPU)
220
  res = biomedlm_infer_local(
 
238
 
239
  # Fallback automático a remoto si está permitido
240
  if BIO_FALLBACK_REMOTE:
241
+ answer2, dbg2 = call_biomedlm_remote(prompt)
 
 
 
 
 
 
 
 
 
242
  updated = (chat_msgs or []) + [
243
  {"role": "user", "content": user_msg},
244
  {"role": "assistant", "content": answer2}
245
  ]
246
+ return updated, "", gr.update(value=f"[Local->Remoto fallback]\n{err_msg}\n{dbg2}")
 
247
  else:
248
  updated = (chat_msgs or []) + [
249
  {"role": "user", "content": user_msg},