VeuReu commited on
Commit
0fb6f95
verified
1 Parent(s): 7b4bcba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -0
app.py CHANGED
@@ -55,6 +55,49 @@ def _build_prompt(prompt: str, system: Optional[str]) -> str:
55
  sys_part = (f"<<SYS>>\n{system.strip()}\n<</SYS>>\n\n" if system and system.strip() else "")
56
  return sys_part + f"### Instrucción\n{prompt}\n\n### Respuesta\n"
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  @spaces.GPU # usa GPU si está disponible (ZeroGPU)
59
  def _generate(
60
  prompt: str,
@@ -113,4 +156,13 @@ with gr.Blocks(title="Salamandra 7B Instruct · ZeroGPU") as demo:
113
  out_engine = gr.Textbox(label="Respuesta (ENGINE)")
114
  gr.Button("Probar /predict").click(predict_for_engine, [in_prompt_engine], out_engine, api_name="predict", concurrency_limit=1)
115
 
 
 
 
 
 
 
 
 
 
116
  demo.queue(max_size=16).launch()
 
55
  sys_part = (f"<<SYS>>\n{system.strip()}\n<</SYS>>\n\n" if system and system.strip() else "")
56
  return sys_part + f"### Instrucción\n{prompt}\n\n### Respuesta\n"
57
 
@spaces.GPU  # use the GPU when available (ZeroGPU)
def _generate_with_tools(
    messages: List[Dict[str, str]],
    tools: List[Dict[str, Any]],
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
) -> Dict[str, Any]:
    """Run one chat completion with tool descriptions injected into the prompt.

    Args:
        messages: Chat history as ``[{"role": ..., "content": ...}, ...]``.
        tools: Tool specifications, rendered into the prompt by ``_render_tools_md``.
        max_new_tokens: Generation budget for the reply.
        temperature: Sampling temperature; ``0`` disables sampling (greedy decode).
        top_p: Nucleus-sampling threshold.

    Returns:
        Dict with ``"text"`` (full decoded output), ``"tool_calls"`` (parsed
        tool-call list, possibly empty) and ``"tool_results"`` (results of
        executing those calls, empty when there were no calls).
    """
    tok, model = _lazy_load()
    tools_md = _render_tools_md(tools)
    prompt = _compose_chat_prompt(messages, tools_md)

    inputs = tok(prompt, return_tensors="pt").to(DEVICE)
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=temperature > 0,
            pad_token_id=tok.eos_token_id,
            eos_token_id=tok.eos_token_id,
        )
    text = tok.decode(out[0], skip_special_tokens=True).strip()

    # If the model emitted a JSON object containing "tool_calls", extract it.
    # BUG FIX: the previous non-greedy regex r"\{.*?\"tool_calls\".*?\}" stopped
    # at the FIRST "}", truncating any nested object (e.g. a call with an
    # "arguments" dict), which made json.loads raise and the bare except
    # silently discard every non-trivial tool call. We now attempt a full
    # balanced JSON decode from each "{" instead.
    tool_calls: List[Dict[str, Any]] = _extract_tool_calls(text)

    tool_results = maybe_execute_tool_calls(tool_calls) if tool_calls else []

    return {"text": text, "tool_calls": tool_calls, "tool_results": tool_results}


def _extract_tool_calls(text: str) -> List[Dict[str, Any]]:
    """Return the ``tool_calls`` list of the last parseable JSON object in *text*.

    Tries a full JSON decode starting at every ``{`` in *text*, so nested
    structures such as ``{"tool_calls": [{"name": ..., "arguments": {...}}]}``
    are parsed completely. Keeps the LAST object that carries a list-valued
    ``tool_calls`` key (matching the original "last match wins" behavior);
    returns ``[]`` when none is found.
    """
    decoder = json.JSONDecoder()
    found: List[Dict[str, Any]] = []
    for match in re.finditer(r"\{", text):
        try:
            obj, _end = decoder.raw_decode(text, match.start())
        except ValueError:
            # json.JSONDecodeError subclasses ValueError: not a JSON object here.
            continue
        if isinstance(obj, dict):
            calls = obj.get("tool_calls")
            if isinstance(calls, list):
                found = calls
    return found
101
  @spaces.GPU # usa GPU si est谩 disponible (ZeroGPU)
102
  def _generate(
103
  prompt: str,
 
156
  out_engine = gr.Textbox(label="Respuesta (ENGINE)")
157
  gr.Button("Probar /predict").click(predict_for_engine, [in_prompt_engine], out_engine, api_name="predict", concurrency_limit=1)
158
 
159
+ with gr.Row():
160
+ prompt = gr.Textbox(label="prompt", lines=10)
161
+ with gr.Row():
162
+ btn2 = gr.Button("Generar", variant="primary")
163
+ with gr.Row():
164
+ out2 = gr.JSON(label="Salida")
165
+
166
+ btn2.click(salamandra_chat_endpoint, [prompt], out2, api_name="generate_out_from_prompt", concurrency_limit=1)
167
+
168
  demo.queue(max_size=16).launch()