jorgeiv500 committed
Commit 0be85e9 · verified · 1 Parent(s): 81e18be

Update app.py

Files changed (1)
  app.py  +74 -33
app.py CHANGED
@@ -1,17 +1,18 @@
-# app.py — DeepSeek-OCR (GPU worker) + TxAgent-T1-Llama-3.1-8B (HF Inference)
-# -----------------------------------------------------------------------------
-# • OCR: DeepSeek-OCR loaded on CPU and moved to GPU ONLY inside @spaces.GPU.
-# • Chat: mims-harvard/TxAgent-T1-Llama-3.1-8B via InferenceClient (serverless), no local CUDA.
-# • Recommended variables in Settings Secrets:
-#     HF_TOKEN=hf_xxx (required for Inference)
+# app.py — DeepSeek-OCR (GPU worker) + TxAgent-T1-Llama-3.1-8B (HF Inference via text_generation)
+# -----------------------------------------------------------------------------------------------
+# • OCR: DeepSeek-OCR loaded on CPU and moved to GPU ONLY inside @spaces.GPU (avoids "CUDA in main").
+# • Chat: mims-harvard/TxAgent-T1-Llama-3.1-8B via InferenceClient.text_generation with the Featherless AI provider.
+# • No queue(concurrency_count): compatible with Gradio 5.
+# • Recommended variables (Settings Secrets):
+#     HF_TOKEN=hf_xxx (required for Inference)
 #     TX_MODEL_ID=mims-harvard/TxAgent-T1-Llama-3.1-8B
-#     TX_PROVIDER=hf-inference
+#     TX_TOKENIZER_ID=mims-harvard/TxAgent-T1-Llama-3.1-8B
 #     GEN_MAX_NEW_TOKENS=512
 #     GEN_TEMPERATURE=0.2
 #     GEN_TOP_P=0.9
 #     OCR_REVISION=<optional commit to pin a stable version>
 #     OCR_ATTN_IMPL=flash_attention_2 (or "eager" if FlashAttention is unavailable)
-# -----------------------------------------------------------------------------
+# -----------------------------------------------------------------------------------------------
 
 import os, tempfile, traceback
 import gradio as gr
@@ -22,23 +23,26 @@ import spaces
 from huggingface_hub import InferenceClient
 
 # =========================
-# Remote TxAgent chat (HF Inference)
+# Config — remote chat (TxAgent via text_generation + Featherless)
 # =========================
-TX_MODEL_ID = os.getenv("TX_MODEL_ID", "mims-harvard/TxAgent-T1-Llama-3.1-8B")
-TX_PROVIDER = os.getenv("TX_PROVIDER", "hf-inference")  # serverless on HF
-HF_TOKEN = os.getenv("HF_TOKEN")  # required
+TX_MODEL_ID = os.getenv("TX_MODEL_ID", "mims-harvard/TxAgent-T1-Llama-3.1-8B")
+TX_TOKENIZER_ID = os.getenv("TX_TOKENIZER_ID", TX_MODEL_ID)
+HF_TOKEN = os.getenv("HF_TOKEN")  # required
 
 GEN_MAX_NEW_TOKENS = int(os.getenv("GEN_MAX_NEW_TOKENS", "512"))
 GEN_TEMPERATURE = float(os.getenv("GEN_TEMPERATURE", "0.2"))
 GEN_TOP_P = float(os.getenv("GEN_TOP_P", "0.9"))
 
-# Remote client (timeout goes in the constructor; do NOT pass timeout to the method)
-tx_client = InferenceClient(
-    model=TX_MODEL_ID,
-    provider=TX_PROVIDER,
-    token=HF_TOKEN,
-    timeout=60.0,
-)
+# Generic client (no model/provider bound; both are passed on each call)
+_hf_client = InferenceClient(token=HF_TOKEN, timeout=60.0)
+
+# Tokenizer used to apply the chat template → prompt
+_TX_TOKENIZER = None
+def get_tx_tokenizer():
+    global _TX_TOKENIZER
+    if _TX_TOKENIZER is None:
+        _TX_TOKENIZER = AutoTokenizer.from_pretrained(TX_TOKENIZER_ID, trust_remote_code=True)
+    return _TX_TOKENIZER
 
 def _system_prompt():
     return (
@@ -59,19 +63,56 @@ def _mk_messages(ocr_md: str, ocr_txt: str, user_msg: str):
     ]
 
 def txagent_chat_remote(ocr_md: str, ocr_txt: str, user_msg: str) -> str:
+    """
+    Use text_generation with the Featherless AI provider.
+    - Convert the messages into a prompt using the tokenizer's chat template.
+    - Call the router with model='mims-harvard/TxAgent…:featherless-ai'.
+    """
     messages = _mk_messages(ocr_md, ocr_txt, user_msg)
-    out = tx_client.chat.completions.create(
-        model=TX_MODEL_ID,
-        messages=messages,
-        max_tokens=GEN_MAX_NEW_TOKENS,
-        temperature=GEN_TEMPERATURE,
-        top_p=GEN_TOP_P,
-        stream=False,
+    tok = get_tx_tokenizer()
+    prompt = tok.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True,  # leaves the assistant turn open
     )
-    return out.choices[0].message.content
+
+    model_with_provider = f"{TX_MODEL_ID}:featherless-ai"
+    try:
+        out = _hf_client.text_generation(
+            model=model_with_provider,
+            prompt=prompt,
+            max_new_tokens=GEN_MAX_NEW_TOKENS,
+            temperature=GEN_TEMPERATURE,
+            top_p=GEN_TOP_P,
+            stream=False,
+        )
+        # In recent huggingface_hub versions, text_generation returns a str (the generated text).
+        return out if isinstance(out, str) else str(out)
+    except Exception as e1:
+        # Fallback: create a client pinned to the provider in case the mapping changes
+        try:
+            client_fb = InferenceClient(
+                model=TX_MODEL_ID,
+                provider="featherless-ai",
+                token=HF_TOKEN,
+                timeout=60.0,
+            )
+            out = client_fb.text_generation(
+                prompt=prompt,
+                max_new_tokens=GEN_MAX_NEW_TOKENS,
+                temperature=GEN_TEMPERATURE,
+                top_p=GEN_TOP_P,
+                stream=False,
+            )
+            return out if isinstance(out, str) else str(out)
+        except Exception as e2:
+            raise RuntimeError(
+                f"Remote generation failed: {e1.__class__.__name__}: {e1} | "
+                f"Fallback: {e2.__class__.__name__}: {e2}"
+            )
 
 # =========================
-# OCR — DeepSeek-OCR (Transformers), CUDA only in the worker
+# OCR — DeepSeek-OCR (Transformers), CUDA only in the GPU worker
 # =========================
 def _best_dtype():
     if torch.cuda.is_available():
@@ -80,7 +121,7 @@ def _best_dtype():
 
 def _load_ocr_model():
     model_id = "deepseek-ai/DeepSeek-OCR"
-    revision = os.getenv("OCR_REVISION", None)  # <-- pin a commit for stability
+    revision = os.getenv("OCR_REVISION", None)  # pin a commit for stability if you want
     attn_impl = os.getenv("OCR_ATTN_IMPL", "flash_attention_2")
 
     tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, revision=revision)
@@ -94,7 +135,7 @@ def _load_ocr_model():
         ).eval()
         return tok, mdl
     except Exception as e:
-        # Fallback if FlashAttention2 is not available
+        # Fallback if FA2 is not available in the environment
        if any(k in str(e).lower() for k in ["flash_attn", "flashattention2", "flash_attention_2"]):
            mdl = AutoModel.from_pretrained(
                model_id,
@@ -108,7 +149,7 @@
 
 OCR_TOKENIZER, OCR_MODEL = _load_ocr_model()
 
-@spaces.GPU  # ← touch CUDA only here
+@spaces.GPU  # ← touch CUDA only here, not in the main process
 def ocr_infer(image: Image.Image, model_size: str, task_type: str, is_eval_mode: bool):
     if image is None:
         return None, "Sube una imagen primero.", "Sube una imagen primero."
@@ -261,7 +302,7 @@ with gr.Blocks(title="OpScanIA — DeepSeek-OCR + TxAgent (HF Inference)", theme
     clear_btn.click(fn=clear_chat, outputs=[chatbot, user_in, error_box])
 
 if __name__ == "__main__":
-    # Note: Gradio 5 has no concurrency_count in queue()
-    # You can launch directly, or use queue(max_size=…)
+    # Gradio 5 no longer has concurrency_count in queue()
+    # Launch directly (or use demo.queue(max_size=…) if you want to cap the queue).
     # demo.queue(max_size=32)
     demo.launch()
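
The messages → prompt → text_generation flow that this commit adopts can be exercised on its own. A minimal sketch, assuming HF_TOKEN is set in the environment and that the installed huggingface_hub accepts the "model:provider" suffix routing used in the diff; the exact prompt string depends on the chat template shipped with the tokenizer:

# Minimal sketch of the chat-template -> text_generation flow used above.
import os
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer

model_id = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

messages = [
    {"role": "system", "content": "You are a clinical assistant."},
    {"role": "user", "content": "Summarize the OCR findings."},
]
# tokenize=False returns the formatted string; add_generation_prompt=True appends
# the header that opens the assistant turn (Llama-3.1-style templates).
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

client = InferenceClient(token=os.getenv("HF_TOKEN"), timeout=60.0)
text = client.text_generation(
    prompt,
    model=f"{model_id}:featherless-ai",  # provider suffix, as in the diff
    max_new_tokens=128,
    temperature=0.2,
    top_p=0.9,
)
print(text)  # text_generation returns the generated string when stream=False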
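
The OCR half keeps the ZeroGPU constraint described in the new header comment: load on CPU at import time, touch CUDA only inside the @spaces.GPU function. A stripped-down sketch of that pattern, with hypothetical infer/inputs names (runs only on HF Spaces, where the spaces package is available):

# Import-time code stays on CPU; CUDA is touched only inside the @spaces.GPU
# function, which executes in the GPU worker process.
import spaces
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "deepseek-ai/DeepSeek-OCR", trust_remote_code=True
).eval()  # still on CPU here; no .cuda() in the main process

@spaces.GPU
def infer(inputs: dict):
    m = model.to("cuda")  # safe: we are inside the GPU worker now
    batch = {k: v.to("cuda") for k, v in inputs.items()}
    with torch.no_grad():
        return m(**batch)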