Marylene commited on
Commit
b6dc1cd
·
verified ·
1 Parent(s): 5ede9de

Update quick_deploy_agent.py

Browse files
Files changed (1) hide show
  1. quick_deploy_agent.py +38 -20
quick_deploy_agent.py CHANGED
@@ -5,11 +5,15 @@ import requests
5
  from smolagents import Tool, CodeAgent, InferenceClientModel
6
  from sentence_transformers import SentenceTransformer, util
7
 
 
 
 
 
 
8
  FALLBACK_MODELS = [
9
- # ordre de préférence ; tous dispos en Inference API publique
10
- "meta-llama/Meta-Llama-3.1-8B-Instruct",
11
- "mistralai/Mistral-7B-Instruct-v0.3",
12
  "Qwen/Qwen2.5-7B-Instruct",
 
13
  ]
14
 
15
 
@@ -633,33 +637,46 @@ class Resolve(Tool):
633
 
634
  # ---- build_agent ----
635
  def build_agent(model_id: str | None = None) -> CodeAgent:
636
- mid = model_id or os.getenv("HF_MODEL_ID") or "meta-llama/Meta-Llama-3.1-8B-Instruct"
637
  model = InferenceClientModel(
638
  model_id=mid,
639
  temperature=0.2,
640
- max_tokens=512, # chat_completion param
641
- timeout=60,
642
  top_p=0.95,
643
  )
644
  agent = CodeAgent(
645
- tools=[
646
- ValidateEANTool(),
647
- OFFByEAN(),
648
- RegexCOICOP(),
649
- OFFtoCOICOP(),
650
- SemSim(),
651
- WebSearch(), # <-- autorise recherche web
652
- WebGet(), # <-- autorise lecture de pages
653
- MergeCandidatesTool(),
654
- Resolve(),
655
- ],
656
  model=model,
657
  add_base_tools=False,
658
- max_steps=8, # un peu plus de marge si web utilisé
659
- verbosity_level=2,
660
  )
661
  return agent
662
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
663
  def parse_result(res):
664
  if isinstance(res, dict): return res
665
  try: return ast.literal_eval(res)
@@ -704,5 +721,6 @@ if __name__ == "__main__":
704
 
705
  Retourne uniquement un JSON valide (objet), sans backticks.
706
  """
707
- out = agent.run(task)
 
708
  print(parse_result(out))
 
5
  from smolagents import Tool, CodeAgent, InferenceClientModel
6
  from sentence_transformers import SentenceTransformer, util
7
 
8
+ # --- Config runtime via env (avec valeurs par défaut sûres sur Space) ---
9
+ HF_TIMEOUT = int(os.getenv("HF_TIMEOUT", "180")) # 180s au lieu de 60s
10
+ HF_MAX_TOKENS = int(os.getenv("HF_MAX_TOKENS", "384")) # réduire un peu la génération
11
+ AGENT_MAX_STEPS = int(os.getenv("AGENT_MAX_STEPS", "6"))
12
+ # Ordre: un modèle préféré, puis 2 replis rapides et dispo publique
13
  FALLBACK_MODELS = [
14
+ os.getenv("HF_MODEL_ID") or "meta-llama/Meta-Llama-3.1-8B-Instruct",
 
 
15
  "Qwen/Qwen2.5-7B-Instruct",
16
+ "HuggingFaceH4/zephyr-7b-beta",
17
  ]
18
 
19
 
 
637
 
638
  # ---- build_agent ----
639
  def build_agent(model_id: str | None = None) -> CodeAgent:
640
+ mid = model_id or FALLBACK_MODELS[0]
641
  model = InferenceClientModel(
642
  model_id=mid,
643
  temperature=0.2,
644
+ max_tokens=HF_MAX_TOKENS,
645
+ timeout=HF_TIMEOUT, # ⬅️ timeout augmenté
646
  top_p=0.95,
647
  )
648
  agent = CodeAgent(
649
+ tools=[ValidateEANTool(), OFFByEAN(), RegexCOICOP(), OFFtoCOICOP(), SemSim(),
650
+ WebSearch(), WebGet(),
651
+ MergeCandidatesTool(), Resolve()],
 
 
 
 
 
 
 
 
652
  model=model,
653
  add_base_tools=False,
654
+ max_steps=AGENT_MAX_STEPS, # ⬅️ moins d’étapes = moins de tokens/latence
655
+ verbosity_level=1, # ⬅️ logs plus courts = moins de tokens sortants
656
  )
657
  return agent
658
 
659
+ # ---- run task with fallback ----
660
+ def run_task_with_fallback(task: str):
661
+ errors = []
662
+ for mid in [m for m in FALLBACK_MODELS if m]:
663
+ try:
664
+ agent = build_agent(mid)
665
+ return agent.run(task)
666
+ except Exception as e:
667
+ errors.append(f"{mid}: {type(e).__name__}: {e}")
668
+ # on tente le modèle suivant
669
+ continue
670
+ # Si TOUT a échoué, renvoyer un JSON propre plutôt qu’un crash
671
+ return {
672
+ "final": None,
673
+ "alternatives": [],
674
+ "candidates_top": [],
675
+ "explanation": "LLM backend indisponible (timeouts).",
676
+ "errors": errors,
677
+ }
678
+
679
+
680
  def parse_result(res):
681
  if isinstance(res, dict): return res
682
  try: return ast.literal_eval(res)
 
721
 
722
  Retourne uniquement un JSON valide (objet), sans backticks.
723
  """
724
+ # out = agent.run(task)
725
+ out = run_task_with_fallback(task)
726
  print(parse_result(out))