| |
| """ |
| mac-tensor agent β Interactive agentic REPL backed by distributed expert nodes. |
| |
| The model can call tools by emitting XML tags. We use STOP SEQUENCES so |
| generation halts the moment a closing tag appears, parse the tool call, |
| execute it, then feed the result back into the SAME KV cache (no reset). |
| |
| Tools: |
| <read>path</read> β read a file |
| <ls>path</ls> β list directory contents |
| <shell>command</shell> β run a read-only shell command |
| <search>query</search> β DuckDuckGo web search |
| <python>expr</python> β restricted python eval |
| <write path="...">content</write> β write a file (requires --write) |
| """ |
|
|
| import argparse |
| import json |
| import os |
| import re |
| import subprocess |
| import sys |
| import urllib.parse |
| import urllib.request |
| from collections import deque |
|
|
| import numpy as np |
|
|
|
|
| |
| |
| |
|
|
| SYSTEM_PROMPT = """You are a coding agent on an Apple Silicon cluster. You answer the user by either: |
| (a) calling ONE tool and stopping, or |
| (b) giving a short final answer with no tags. |
| |
| Rules β these are critical: |
| 1. NEVER write a <result> tag yourself. The system inserts results. |
| 2. After emitting a tool call, STOP. Do not write anything else. |
| 3. Use exactly ONE tool per step. The system runs it and shows you the result. |
| 4. When you have enough information, give a SHORT final answer (1-3 sentences). |
| 5. Do not repeat yourself. Do not loop. |
| |
| Tools: |
| <read>path</read> Read a file (use ~ for home). |
| <ls>path</ls> List a directory. |
| <shell>cmd</shell> Run a read-only shell command. |
| <search>query</search> Web search (DuckDuckGo). |
| <python>expr</python> Eval a Python expression. |
| |
| Example interaction: |
| |
| User: What's in the README? |
| You: <read>README.md</read> |
| [system inserts: <result># My Project\\nA tool for X.</result>] |
| You: The README describes a tool for X. |
| |
| Now the user asks: |
| """ |
|
|
|
|
| VISION_SYSTEM_PROMPT = """You are a vision agent on an Apple Silicon cluster. You can see the image the user uploaded, and you have a precision instrument called Falcon Perception that finds exact pixel locations of objects. |
| |
| Rules: |
| 1. NEVER write a <result> tag yourself. The system inserts results. |
| 2. After emitting a tool call, STOP. Wait for the result. |
| 3. Use ONE tool per step. |
| 4. When you have enough information, give a SHORT final answer (1-3 sentences) with no tags. |
| 5. Trust the numbers the tools give you over your own visual estimates. |
| |
| When to use tools vs answer directly: |
| - If the user just wants a description of the image, answer directly. No tools needed. |
| - If the user asks "how many" or "where exactly" or "which is largest/closest/leftmost", use grounding tools. |
| - If the user asks for spatial relationships (closest, biggest, leftmost), first ground the objects, then use a spatial tool. |
| |
| Tools: |
| |
| <ground>object name or short description</ground> |
| Find all instances of an object in the image. Returns a list of masks |
| with id, centroid (x, y in 0-1), area_fraction, region (top-left, center, etc.). |
| Example: <ground>bird</ground> |
| |
| <extreme slot="object_name" direction="bottommost"/> |
| After grounding, find the most extreme mask in a slot. |
| direction can be: topmost, bottommost, leftmost, rightmost, largest, smallest |
| In a 2D photo, "closest to camera" usually means bottommost (lowest in frame). |
| Example: <extreme slot="bird" direction="bottommost"/> |
| |
| <count slot="object_name"/> |
| Return the number of masks in a slot. |
| Example: <count slot="bird"/> |
| |
| Example interaction (single ground): |
| User: How many cats are in this image? |
| You: <ground>cat</ground> |
| [system: <result>{"slot":"cat","count":3,"masks":[{"id":1,...},{"id":2,...},{"id":3,...}]}</result>] |
| You: There are 3 cats in the image. |
| |
| Example interaction (chained): |
| User: Which bird is biggest? |
| You: <ground>bird</ground> |
| [system: <result>{"slot":"bird","count":4,"masks":[...]}</result>] |
| You: <extreme slot="bird" direction="largest"/> |
| [system: <result>{"winner":{"id":3,"area_fraction":0.18,...}}</result>] |
| You: The largest bird is mask #3, taking up about 18% of the image. |
| |
| Now the user asks: |
| """ |
|
|
|
|
| |
| |
| |
|
|
|
|
def tool_read(arg):
    """Read a text file and return its contents.

    Files longer than 4000 characters are truncated with a note, so a huge
    file cannot blow out the model's context window. Any failure (missing
    file, permissions, decode error) is returned as an error string.
    """
    try:
        target = os.path.expanduser(arg.strip())
        with open(target) as fh:
            data = fh.read()
        if len(data) > 4000:
            data = data[:4000] + f"\n\n[truncated, {len(data) - 4000} more chars]"
        return data
    except Exception as e:
        return f"Error reading {arg}: {e}"
|
|
|
|
def tool_ls(arg):
    """List a directory, marking subdirectories with a trailing '/'.

    Output is capped at 50 entries (sorted) with an elision note for the
    rest; errors come back as strings rather than raising.
    """
    try:
        root = os.path.expanduser(arg.strip())
        entries = sorted(os.listdir(root))
        lines = []
        for name in entries[:50]:
            marker = "/" if os.path.isdir(os.path.join(root, name)) else ""
            lines.append(f" {name}{marker}")
        if len(entries) > 50:
            lines.append(f" ... ({len(entries) - 50} more)")
        return "\n".join(lines) if lines else "(empty)"
    except Exception as e:
        return f"Error: {e}"
|
|
|
|
def tool_shell(arg, allow_write=False):
    """Run a shell command with a 15 second timeout.

    Unless allow_write is set, commands containing obviously destructive
    substrings are refused. This is a best-effort blocklist, not a sandbox —
    it only guards against accidental damage.
    """
    cmd = arg.strip()
    if not allow_write:
        blocklist = ("rm ", "mv ", "dd ", "mkfs", ">", ">>", "chmod",
                     "chown", "sudo", "kill", "shutdown", "format", "fdisk")
        lowered = cmd.lower()
        for token in blocklist:
            if token in lowered:
                return f"Refused (destructive: '{token}')"
    try:
        proc = subprocess.run(cmd, shell=True, capture_output=True,
                              text=True, timeout=15)
        stdout = (proc.stdout or "").strip()
        stderr = (proc.stderr or "").strip()
        if not stdout and not stderr:
            return f"(exit {proc.returncode}, no output)"
        chunks = []
        if stdout:
            chunks.append(stdout[:2000] + ("..." if len(stdout) > 2000 else ""))
        if stderr:
            chunks.append(f"[stderr] {stderr[:500]}")
        if proc.returncode != 0:
            chunks.append(f"[exit {proc.returncode}]")
        return "\n".join(chunks)
    except subprocess.TimeoutExpired:
        return "Error: timeout (15s)"
    except Exception as e:
        return f"Error: {e}"
|
|
|
|
def tool_search(arg):
    """DuckDuckGo web search via the HTML (no-JS) endpoint.

    Scrapes up to four result titles and snippets with regexes and returns
    them as bullet entries. Any network or parsing failure is returned as an
    error string rather than raised.
    """
    query = arg.strip()
    try:
        url = f"https://html.duckduckgo.com/html/?q={urllib.parse.quote(query)}"
        # Browser-like User-Agent — presumably needed so the endpoint doesn't
        # reject the default urllib UA; TODO confirm if results ever go empty.
        req = urllib.request.Request(
            url,
            headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
                                   "AppleWebKit/537.36"},
        )
        with urllib.request.urlopen(req, timeout=10) as resp:
            html = resp.read().decode("utf-8", errors="ignore")
        results = []
        # Matches each result's anchor text and snippet in the static HTML page.
        pattern = re.compile(
            r'class="result__a"[^>]*>([^<]+)</a>.*?'
            r'class="result__snippet"[^>]*>(.+?)</a>',
            re.DOTALL,
        )
        for m in list(pattern.finditer(html))[:4]:
            title = re.sub(r"\s+", " ", m.group(1).strip())
            snippet = re.sub(r"<[^>]+>", "", m.group(2))  # strip inline tags
            snippet = re.sub(r"\s+", " ", snippet).strip()[:160]
            # BUG FIX: the bullet was mojibake ("β’") from a bad encoding
            # round-trip; restored the intended "•".
            results.append(f"• {title}\n {snippet}")
        return "\n\n".join(results) if results else "No results"
    except Exception as e:
        return f"Search failed: {e}"
|
|
|
|
def tool_python(arg):
    """Evaluate a Python expression with a whitelist of safe builtins.

    Only pure, side-effect-free builtins are exposed; anything else raises a
    NameError which is returned as an error string. NOTE: an eval-based
    whitelist is a convenience guard, not a hard security boundary.
    """
    try:
        permitted = [abs, all, any, bool, chr, dict, divmod, enumerate,
                     filter, float, hex, int, len, list, map, max, min,
                     ord, pow, range, reversed, round, set, sorted, str,
                     sum, tuple, zip]
        safe = {fn.__name__: fn for fn in permitted}
        return str(eval(arg.strip(), {"__builtins__": safe}, {}))
    except Exception as e:
        return f"Error: {e}"
|
|
|
|
def tool_write(path_arg, content, allow_write=False):
    """Write content to a file, creating parent directories as needed.

    Disabled unless the REPL was started with --write; returns a status or
    error string instead of raising.
    """
    if not allow_write:
        return "Refused. Pass --write to enable."
    try:
        dest = os.path.expanduser(path_arg.strip())
        parent = os.path.dirname(dest) or "."
        os.makedirs(parent, exist_ok=True)
        with open(dest, "w") as fh:
            fh.write(content)
        return f"Wrote {len(content)} bytes to {dest}"
    except Exception as e:
        return f"Error: {e}"
|
|
|
|
| |
| |
| |
|
|
| |
# --- Tool-call grammar -------------------------------------------------------
# Content-style tags: <tool>argument</tool>
TOOL_REGEX = re.compile(
    r"<(read|ls|shell|search|python|ground|count)>(.+?)</\1>",
    re.DOTALL,
)
# <write> carries its target path as an attribute; the body is the content.
WRITE_REGEX = re.compile(
    r'<write\s+path="([^"]+)">(.+?)</write>',
    re.DOTALL,
)
# Attribute-style, self-closing vision tags.
EXTREME_REGEX = re.compile(
    r'<extreme\s+slot="([^"]+)"\s+direction="([^"]+)"\s*/>',
)
COUNT_ATTR_REGEX = re.compile(
    r'<count\s+slot="([^"]+)"\s*/>',
)
BBOX_REGEX = re.compile(
    r'<bbox\s+id="?(\d+)"?\s*/>',
)

# Generation halts the moment one of these appears in the decoded text, so the
# model cannot keep writing past its own tool call. "/>" covers all
# self-closing tags.
STOP_SEQUENCES = [
    "</read>", "</ls>", "</shell>", "</search>", "</python>", "</write>",
    "</ground>", "</count>",
    "/>",
]

# The vision agent also stops on a hallucinated "</result>".
VISION_STOP_SEQUENCES = [
    "</ground>", "</count>",
    "/>",
    "</result>",
]


def parse_first_tool(text):
    """Find the first complete tool call in text. Returns dict or None.

    Attribute-style tags are tried before the generic content-style pattern
    so that e.g. <count slot="x"/> is not shadowed by <count>x</count>.
    """
    def _with_span(match, fields):
        fields["start"] = match.start()
        fields["end"] = match.end()
        return fields

    m = WRITE_REGEX.search(text)
    if m is not None:
        return _with_span(m, {"tool": "write", "path": m.group(1),
                              "content": m.group(2)})

    m = EXTREME_REGEX.search(text)
    if m is not None:
        return _with_span(m, {"tool": "extreme", "slot": m.group(1),
                              "direction": m.group(2), "content": ""})

    m = COUNT_ATTR_REGEX.search(text)
    if m is not None:
        return _with_span(m, {"tool": "count", "slot": m.group(1),
                              "content": ""})

    m = BBOX_REGEX.search(text)
    if m is not None:
        return _with_span(m, {"tool": "bbox", "mask_id": int(m.group(1)),
                              "content": ""})

    m = TOOL_REGEX.search(text)
    if m is not None:
        return _with_span(m, {"tool": m.group(1), "content": m.group(2)})

    return None
|
|
|
|
def execute_tool(call, allow_write=False, falcon_tools=None):
    """Dispatch a parsed tool call (from parse_first_tool) to its implementation.

    Text/filesystem tools are always available; grounding tools need the
    Falcon Perception backend and return a refusal string without it.
    """
    name = call["tool"]

    # Text / filesystem tools.
    if name == "read":
        return tool_read(call["content"])
    if name == "ls":
        return tool_ls(call["content"])
    if name == "shell":
        return tool_shell(call["content"], allow_write=allow_write)
    if name == "search":
        return tool_search(call["content"])
    if name == "python":
        return tool_python(call["content"])
    if name == "write":
        return tool_write(call["path"], call["content"], allow_write=allow_write)

    # Everything from here on requires the vision backend.
    if falcon_tools is None:
        return f"Vision tool '{name}' requires Falcon Perception (--falcon)"
    if name == "ground":
        return falcon_tools.ground(call["content"].strip())
    if name == "extreme":
        return falcon_tools.extreme(call["slot"], call["direction"])
    if name == "count":
        # Slot may arrive as an attribute (<count slot="x"/>) or as tag
        # content (<count>x</count>); accept either form.
        slot = call.get("slot") or call.get("content", "").strip()
        return falcon_tools.count_slot(slot)
    if name == "bbox":
        return falcon_tools.bbox(call["mask_id"])
    return f"Unknown tool: {name}"
|
|
|
|
| |
| |
| |
|
|
|
|
class AgentBackend:
    """Wraps the distributed engine with token-level generation, stop sequences,
    and repetition penalty. Maintains a SINGLE KV cache across the whole turn —
    we never reset it between tool calls, we just append the result tokens.
    """

    def __init__(self, model_key, node_urls):
        # model_key: "gemma4" or "qwen35" — selects engine class and tokenizer API.
        # node_urls: URLs of the distributed expert nodes.
        self.model_key = model_key
        self.node_urls = node_urls
        self.engine = None  # constructed lazily by load()

    def load(self):
        """Import and construct the model-specific distributed engine.

        Imports are deferred to here because the engine modules live outside
        this file (sibling directory, and optionally ~/cli-agent/src for
        gemma4) and pull in heavy dependencies.
        """
        script_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
        sys.path.insert(0, script_dir)

        if self.model_key == "gemma4":
            cli_agent_src = os.path.expanduser("~/cli-agent/src")
            if os.path.exists(cli_agent_src):
                sys.path.insert(0, cli_agent_src)
            from gemma4_distributed import Gemma4DistributedEngine
            self.engine = Gemma4DistributedEngine(node_urls=self.node_urls)
        else:
            from distributed_interactive import InteractiveDistributedEngine
            self.engine = InteractiveDistributedEngine(node_urls=self.node_urls)

        self.engine.load()

    def reset(self):
        """Clear the engine's KV cache (fresh conversation context)."""
        self.engine.reset_cache()

    def encode(self, text):
        """Tokenize text into a list of token ids via the engine's tokenizer."""
        if self.model_key == "gemma4":
            return self.engine.encode(text)
        return self.engine.tokenizer.encode(text)

    def decode(self, ids):
        """Detokenize ids to text. Special tokens are kept so stop sequences
        and EOS markers remain visible in the decoded buffer."""
        if self.model_key == "gemma4":
            return self.engine.decode(ids)
        return self.engine.tokenizer.decode(ids, skip_special_tokens=False)

    def encode_chat_prefix(self, system_and_user_text):
        """Wrap the prompt in the model's chat template.

        Falls back to a plain encode when the tokenizer has no usable chat
        template (apply_chat_template raises).
        """
        if self.model_key == "gemma4":
            return self.engine.encode_chat(system_and_user_text)

        try:
            messages = [{"role": "user", "content": system_and_user_text}]
            text = self.engine.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True,
            )
            return self.engine.tokenizer.encode(text)
        except Exception:
            return self.encode(system_and_user_text)

    def feed_tokens(self, token_ids):
        """Run forward on a batch of new tokens. Returns final logits.

        Token IDs are appended to the existing KV cache — no reset.
        """
        import mlx.core as mx
        ids = mx.array([token_ids])
        logits = self.engine.forward(ids)
        mx.eval(logits)  # force the lazy mlx graph before sampling
        return logits

    def generate_until_stop(self, prompt_text, stop_sequences,
                            max_tokens=400, temperature=0.5,
                            repetition_penalty=1.15, repetition_window=64,
                            on_chunk=None):
        """Generate tokens until any stop sequence appears or max_tokens hit.

        Returns the decoded text. When a stop sequence is hit it is INCLUDED
        at the end of the returned buffer — callers rely on the closing tag
        being present so parse_first_tool can match the tool call.

        Uses persistent KV cache: if you call this multiple times in a row,
        each call appends to the same cache.

        on_chunk, if given, receives each newly decoded text fragment as it
        is produced.
        """
        import mlx.core as mx
        import numpy as np

        prompt_ids = self.encode_chat_prefix(prompt_text)
        logits = self.feed_tokens(prompt_ids)

        generated_ids = []
        recent = deque(maxlen=repetition_window)  # sliding window for the penalty
        decoded_buffer = ""

        # End-of-sequence token ids per model.
        # NOTE(review): hard-coded; assumed to match each tokenizer's special
        # tokens — confirm against the tokenizer configs.
        if self.model_key == "gemma4":
            eos_set = {1, 106}
        else:
            eos_set = {248044, 248046}

        for step in range(max_tokens):
            last_logits = logits[0, -1]

            # Repetition penalty: dampen logits of recently generated tokens
            # (divide positive logits, multiply negative ones).
            if recent and repetition_penalty != 1.0:
                last_np = np.array(last_logits.astype(mx.float32))
                for tid in set(recent):
                    if last_np[tid] > 0:
                        last_np[tid] /= repetition_penalty
                    else:
                        last_np[tid] *= repetition_penalty
                last_logits = mx.array(last_np)

            # Sampling: greedy at temperature <= 0, otherwise categorical.
            if temperature <= 0:
                next_token = int(mx.argmax(last_logits).item())
            else:
                probs = mx.softmax(last_logits / temperature, axis=-1)
                next_token = int(mx.random.categorical(mx.log(probs + 1e-10)).item())

            generated_ids.append(next_token)
            recent.append(next_token)

            if next_token in eos_set:
                break

            # Re-decode the whole sequence and diff against the previous
            # buffer: decoding incrementally can split multi-token characters.
            new_decoded = self.decode(generated_ids)
            new_chunk = new_decoded[len(decoded_buffer):]
            decoded_buffer = new_decoded

            if on_chunk and new_chunk:
                on_chunk(new_chunk)

            # Stop as soon as any stop sequence shows up in the decoded text.
            stop_hit = None
            for s in stop_sequences:
                if s in decoded_buffer:
                    stop_hit = s
                    break
            if stop_hit:
                break

            # Append the sampled token to the KV cache; get next-step logits.
            logits = self.feed_tokens([next_token])

        return decoded_buffer
|
|
|
|
| |
| |
| |
|
|
|
|
class VisionAgentBackend:
    """Wraps VisionGemma4Sniper + FalconPerceptionTools for chained tool use."""

    def __init__(self, vision_engine, falcon_tools=None):
        # vision_engine: object exposing .sniper (model + tokenizer + cache)
        #                and .encode_chat / ._prefill_logits.
        # falcon_tools:  optional perception toolset for <ground>/<extreme>/etc.
        self.vision_engine = vision_engine
        self.falcon_tools = falcon_tools

    def reset(self):
        """Clear the sniper's KV cache (fresh conversation context)."""
        self.vision_engine.sniper.reset_cache()

    def encode(self, text):
        """Tokenize text. This tokenizer returns an Encoding object, hence .ids."""
        return self.vision_engine.sniper.tokenizer.encode(text).ids

    def decode(self, ids):
        """Detokenize a list of token ids back to text."""
        return self.vision_engine.sniper.tokenizer.decode(ids)

    def feed_tokens_text(self, token_ids):
        """Feed text-only tokens after the initial vision prefill.

        Appends to the existing KV cache and returns the resulting logits.
        """
        import mlx.core as mx
        ids = mx.array([token_ids])
        logits = self.vision_engine.sniper.forward(ids)
        mx.eval(logits)
        return logits

    def run_vision_prefill(self, prompt_text, image_path):
        """Run the initial vision-aware forward pass.

        Also points the Falcon tools at the same image so later grounding
        calls operate on the picture the model is looking at.

        Returns (logits, number_of_prompt_tokens).
        """
        if self.falcon_tools and image_path:
            self.falcon_tools.set_image(image_path)
        tokens, image_features, _ = self.vision_engine.encode_chat(prompt_text, image_path)
        logits = self.vision_engine._prefill_logits(tokens, image_features)
        return logits, len(tokens)
|
|
|
|
def _patch_vision_engine_prefill():
    """Add a _prefill_logits method to VisionGemma4Sniper.

    The original _prefill returns the first sampled token. We need a version
    that returns the full logits so the agent can apply repetition penalty.

    Silently does nothing if the vision engine cannot be imported (text-only
    install) or the method has already been patched in.
    """
    import mlx.core as mx
    try:
        from .vision_engine import VisionGemma4Sniper, IMAGE_TOKEN_ID
    except Exception:
        # Vision stack unavailable — skip patching.
        return

    if hasattr(VisionGemma4Sniper, "_prefill_logits"):
        return  # already patched

    def _prefill_logits(self, input_token_ids, image_features):
        """Full prefill forward pass; returns logits for every position.

        NOTE(review): this mirrors the engine's own forward (attention +
        dense FFN + optional MoE block per layer) — keep in sync with the
        engine and moe_agent_gemma4.run_expert_ffn if either changes.
        """
        import numpy as np
        from mlx_lm.models.base import create_attention_mask
        from mlx_vlm.models.gemma4.gemma4 import masked_scatter
        from moe_agent_gemma4 import run_expert_ffn

        # Fresh KV cache for the new conversation.
        self.sniper.reset_cache()
        input_ids = mx.array([input_token_ids])
        h = self.sniper.model.model.embed_tokens(input_ids)
        h = h * (self.sniper.model.args.hidden_size ** 0.5)  # embedding scale

        # Scatter projected image features into the image-token positions.
        if image_features is not None:
            image_mask = (input_ids == IMAGE_TOKEN_ID)
            image_feats_flat = image_features.reshape(-1, image_features.shape[-1])
            image_feats_flat = image_feats_flat.astype(h.dtype)
            image_mask_expanded = mx.expand_dims(image_mask, -1)
            image_mask_expanded = mx.broadcast_to(image_mask_expanded, h.shape)
            h = masked_scatter(h, image_mask_expanded, image_feats_flat)

        mask = create_attention_mask(h, self.sniper.cache[0] if self.sniper.cache else None)

        for i in range(self.sniper.num_layers):
            layer = self.sniper.model.model.layers[i]
            cache_i = self.sniper.cache[i] if self.sniper.cache else None

            # Attention sub-block with residual connection.
            residual = h
            h_norm = layer.input_layernorm(h)
            h_attn = layer.self_attn(h_norm, mask=mask, cache=cache_i)
            h_attn = layer.post_attention_layernorm(h_attn)
            h = residual + h_attn
            mx.eval(h)

            # Dense feed-forward path.
            residual = h
            h_ff = layer.pre_feedforward_layernorm(h)
            h_ff = layer.mlp(h_ff)

            # MoE path: route each position to its top-k experts, fetch only
            # the experts that are actually used, and add the expert output
            # to the dense output.
            if layer.enable_moe_block:
                h_dense = layer.post_feedforward_layernorm_1(h_ff)
                B, L, D = residual.shape
                residual_flat = residual.reshape(-1, D)
                router = layer.router
                x_normed = router._inline_rms_norm(residual_flat)
                x_normed = x_normed * router.scale * (router.hidden_size ** -0.5)
                scores = router.proj(x_normed)
                probs = mx.softmax(scores, axis=-1)
                top_k_indices = mx.argpartition(-probs, kth=router.top_k - 1, axis=-1)[..., :router.top_k]
                top_k_weights = mx.take_along_axis(probs, top_k_indices, axis=-1)
                # Renormalize the selected weights, then apply per-expert scale.
                top_k_weights = top_k_weights / mx.sum(top_k_weights, axis=-1, keepdims=True)
                expert_scales = router.per_expert_scale[top_k_indices]
                top_k_weights = top_k_weights * expert_scales
                moe_input = layer.pre_feedforward_layernorm_2(residual_flat)
                mx.eval(moe_input, top_k_indices, top_k_weights)
                top_k_indices_r = top_k_indices.reshape(B, L, -1)
                top_k_weights_r = top_k_weights.reshape(B, L, -1)
                # Only load experts routed to in this batch.
                active_ids = list(set(int(e) for e in np.array(top_k_indices_r).flatten()))
                expert_data = self.sniper.reader.get_experts(i, active_ids)
                moe_input_r = moe_input.reshape(B, L, D)
                expert_out = run_expert_ffn(moe_input_r, expert_data, top_k_indices_r, top_k_weights_r)
                h_moe = layer.post_feedforward_layernorm_2(expert_out)
                h_ff = h_dense + h_moe

            h_ff = layer.post_feedforward_layernorm(h_ff)
            h = residual + h_ff
            h = h * layer.layer_scalar
            mx.eval(h)        # materialize per layer to bound peak memory
            mx.clear_cache()  # release mlx scratch buffers between layers

        # Final norm + output projection (tied embeddings when configured).
        h = self.sniper.model.model.norm(h)
        if self.sniper.model.args.tie_word_embeddings:
            logits = self.sniper.model.model.embed_tokens.as_linear(h)
        else:
            logits = self.sniper.model.lm_head(h)
        mx.eval(logits)
        return logits

    VisionGemma4Sniper._prefill_logits = _prefill_logits


# Apply the monkey-patch at import time (no-op when vision is unavailable).
_patch_vision_engine_prefill()
|
|
|
|
def run_vision_agent_turn_stream(vision_backend, user_question, image_path,
                                 max_iterations=4, max_tokens=300,
                                 temperature=0.5, allow_write=False):
    """Run one user turn with vision + Falcon tool use.

    Generator yielding event dicts with the same shapes as
    run_agent_turn_stream: step_start / token / tool_call / tool_result /
    final / done / error. Tool results are fed back into the SAME KV cache
    between iterations — the cache is reset only once, at the start.
    """
    import mlx.core as mx
    import json
    from collections import deque

    vision_backend.reset()

    prompt = VISION_SYSTEM_PROMPT + " " + user_question

    yield {"type": "step_start", "step": 1, "max": max_iterations}

    try:
        logits, _ = vision_backend.run_vision_prefill(prompt, image_path)
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield {"type": "error", "message": f"vision prefill failed: {e}"}
        return

    # Sample the very first token directly from the prefill logits.
    last_logits = logits[0, -1]
    if temperature <= 0:
        next_token = int(mx.argmax(last_logits).item())
    else:
        probs = mx.softmax(last_logits / temperature, axis=-1)
        next_token = int(mx.random.categorical(mx.log(probs + 1e-10)).item())

    # Hard-coded gemma4 end-of-sequence ids (this path is gemma4-only).
    eos_set = {1, 106}

    for iteration in range(max_iterations):
        if iteration > 0:
            yield {"type": "step_start", "step": iteration + 1, "max": max_iterations}

        generated = [next_token]
        recent = deque([next_token], maxlen=64)  # repetition-penalty window
        decoded_buffer = vision_backend.decode([next_token])
        if decoded_buffer:
            yield {"type": "token", "text": decoded_buffer}

        for step in range(max_tokens):
            if next_token in eos_set:
                break

            # Halt as soon as a tool tag (or stray </result>) has closed.
            stop_hit = any(s in decoded_buffer for s in VISION_STOP_SEQUENCES)
            if stop_hit:
                break

            try:
                logits = vision_backend.feed_tokens_text([next_token])
            except Exception as e:
                yield {"type": "error", "message": str(e)}
                return

            last_logits = logits[0, -1]

            # Repetition penalty, fixed at 1.15 over the last 64 tokens
            # (divide positive logits, multiply negative ones).
            if recent:
                last_np = np.array(last_logits.astype(mx.float32))
                for tid in set(recent):
                    if last_np[tid] > 0:
                        last_np[tid] /= 1.15
                    else:
                        last_np[tid] *= 1.15
                last_logits = mx.array(last_np)

            # Greedy at temperature <= 0, otherwise categorical sampling.
            if temperature <= 0:
                next_token = int(mx.argmax(last_logits).item())
            else:
                probs = mx.softmax(last_logits / temperature, axis=-1)
                next_token = int(mx.random.categorical(mx.log(probs + 1e-10)).item())

            generated.append(next_token)
            recent.append(next_token)

            # Re-decode the whole sequence and diff to stream clean chunks.
            full = vision_backend.decode(generated)
            new_chunk = full[len(decoded_buffer):]
            decoded_buffer = full
            if new_chunk:
                yield {"type": "token", "text": new_chunk}

        # Strip any hallucinated <result> blocks before parsing for tools.
        text = re.sub(r"<result>.*?</result>", "", decoded_buffer, flags=re.DOTALL)
        call = parse_first_tool(text)

        if call is None:
            # No tool tag: the model produced its final answer.
            final_answer = text.strip()
            yield {"type": "final", "text": final_answer}
            yield {"type": "done"}
            return

        tool_name = call["tool"]
        tool_args = call.get("content") or call.get("slot") or str(call.get("mask_id", ""))
        yield {"type": "tool_call", "tool": tool_name, "args": str(tool_args)[:200]}

        result = execute_tool(
            call, allow_write=allow_write, falcon_tools=vision_backend.falcon_tools
        )

        # Falcon results are dicts; drop private "_"-prefixed mask fields
        # before serializing what the model sees.
        if isinstance(result, dict):
            if "masks" in result:
                trimmed = []
                for m in result["masks"]:
                    trimmed.append({k: v for k, v in m.items() if not k.startswith("_")})
                result_for_llm = {**result, "masks": trimmed}
            else:
                result_for_llm = result
            result_str = json.dumps(result_for_llm)
            display_str = json.dumps(result_for_llm, indent=2)[:1500]
        else:
            result_str = str(result)
            display_str = result_str[:1500]

        yield {"type": "tool_result", "result": display_str}

        # Feed the tool result back into the SAME KV cache and continue.
        continuation = (
            f"<result>{result_str}</result>\n\n"
            f"Now either call another tool or give a final answer (no tags):"
        )

        try:
            cont_tokens = vision_backend.encode(continuation)
            logits = vision_backend.feed_tokens_text(cont_tokens)
        except Exception as e:
            yield {"type": "error", "message": str(e)}
            return

        # Sample the first token of the next step from the continuation logits.
        last_logits = logits[0, -1]
        if temperature <= 0:
            next_token = int(mx.argmax(last_logits).item())
        else:
            probs = mx.softmax(last_logits / temperature, axis=-1)
            next_token = int(mx.random.categorical(mx.log(probs + 1e-10)).item())

    yield {"type": "final", "text": "(iteration limit reached)"}
    yield {"type": "done"}
|
|
|
|
| |
| |
| |
|
|
|
|
def run_agent_turn_stream(backend, user_question, max_iterations=6,
                          max_tokens=300, temperature=0.5, allow_write=False):
    """Run one user turn as a generator yielding events.

    Each yield is a dict with a 'type' key. Types:
    - {"type": "step_start", "step": N, "max": M}
    - {"type": "token", "text": "..."} — streaming tokens
    - {"type": "tool_call", "tool": "ls", "args": "."}
    - {"type": "tool_result", "result": "..."}
    - {"type": "final", "text": "..."} — final answer
    - {"type": "done"}
    - {"type": "error", "message": "..."}

    Note: token events are buffered during generation and flushed once per
    step (after generate_until_stop returns), not live token-by-token.
    """
    backend.reset()
    prompt = SYSTEM_PROMPT + " " + user_question
    final_answer = None
    continuation = None

    # generate_until_stop reports text via a callback; a plain callback
    # cannot yield from this generator, so fragments are buffered here and
    # flushed after each step completes.
    chunk_buffer = []

    def on_chunk(text):
        chunk_buffer.append(text)

    for step in range(max_iterations):
        yield {"type": "step_start", "step": step + 1, "max": max_iterations}
        chunk_buffer.clear()

        try:
            text = backend.generate_until_stop(
                prompt if step == 0 else continuation,
                stop_sequences=STOP_SEQUENCES + ["</result>"],
                max_tokens=max_tokens,
                temperature=temperature,
                on_chunk=on_chunk,
            )
        except Exception as e:
            yield {"type": "error", "message": str(e)}
            return

        # Flush the buffered fragments as token events.
        for c in chunk_buffer:
            yield {"type": "token", "text": c}

        # Strip any hallucinated <result> blocks before parsing for tools.
        text = re.sub(r"<result>.*?</result>", "", text, flags=re.DOTALL)

        call = parse_first_tool(text)

        if call is None:
            # No tool tag means the model gave its final answer.
            final_answer = text.strip()
            yield {"type": "final", "text": final_answer}
            yield {"type": "done"}
            return

        tool_name = call["tool"]
        tool_arg = call.get("content", "")
        yield {"type": "tool_call", "tool": tool_name, "args": tool_arg[:200]}

        result = execute_tool(call, allow_write=allow_write)
        yield {"type": "tool_result", "result": result[:2000]}

        # The result text is appended to the SAME KV cache next iteration.
        continuation = (
            f"<result>\n{result}\n</result>\n\n"
            f"Now either call another tool or give a final answer (no tags):"
        )

    yield {"type": "final", "text": "(iteration limit reached)"}
    yield {"type": "done"}
|
|
|
|
def run_agent_turn(backend, user_question, max_iterations=6, max_tokens=300,
                   temperature=0.5, allow_write=False, verbose=True):
    """Synchronous wrapper around run_agent_turn_stream — used by the REPL.

    Consumes the event stream, optionally pretty-printing progress, and
    returns the final answer text (or "(no answer)" if none was produced).
    """
    final = "(no answer)"
    for event in run_agent_turn_stream(
        backend, user_question,
        max_iterations=max_iterations,
        max_tokens=max_tokens,
        temperature=temperature,
        allow_write=allow_write,
    ):
        et = event["type"]
        if et == "step_start" and verbose:
            print(f"\n [Step {event['step']}/{event['max']}]", end=" ", flush=True)
        elif et == "token" and verbose:
            print(event["text"], end="", flush=True)
        elif et == "tool_call" and verbose:
            preview = event["args"][:60].replace("\n", " ")
            # BUG FIX: "β" was a mis-encoded arrow glyph; restored "→"/"←"
            # (call / result) — presumed original characters, confirm in VCS.
            print(f"\n → tool: <{event['tool']}> {preview}...")
        elif et == "tool_result" and verbose:
            preview = event["result"][:120].replace("\n", " | ")
            print(f" ← {preview}{'...' if len(event['result']) > 120 else ''}")
        elif et == "final":
            final = event["text"]
        elif et == "error" and verbose:
            print(f"\n ERROR: {event['message']}")
    return final
|
|
|
|
def agent_loop(backend, max_iterations=6, max_tokens=400, allow_write=False):
    """Interactive REPL loop: read a question, run an agent turn, print answer.

    Commands: 'reset' clears the KV-cache context; 'quit'/'exit'/'q'
    (or EOF / Ctrl-C) exits the loop.
    """
    print()
    print("=" * 60)
    # BUG FIX: banner "β" was a mis-encoded em dash; restored "—".
    print(" mac-tensor agent (v2 — stop sequences + repetition penalty)")
    print(f" Model: {backend.model_key} | Nodes: {len(backend.node_urls)}")
    print(f" Tools: read, ls, shell, search, python" + (", write" if allow_write else ""))
    print("=" * 60)
    print("\nType your question. 'reset' to clear, 'quit' to exit.")
    print("-" * 60)

    while True:
        try:
            user_input = input("\nYou: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye!")
            return

        if not user_input:
            continue
        if user_input.lower() in ("quit", "exit", "q"):
            print("Bye!")
            return
        if user_input.lower() == "reset":
            backend.reset()
            print("Context cleared.")
            continue

        answer = run_agent_turn(
            backend, user_input,
            max_iterations=max_iterations,
            max_tokens=max_tokens,
            allow_write=allow_write,
        )

        print(f"\n\nAgent: {answer}\n")
|
|
|
|
def main(args):
    """Validate CLI arguments, load the distributed backend, start the REPL."""
    if not args.nodes:
        print("Error: --nodes is required")
        sys.exit(1)

    urls = [chunk.strip() for chunk in args.nodes.split(",")]
    chosen_model = args.model or "gemma4"

    print(f"Loading {chosen_model} distributed engine...")
    backend = AgentBackend(model_key=chosen_model, node_urls=urls)
    backend.load()

    agent_loop(
        backend,
        max_iterations=args.max_iterations or 6,
        max_tokens=args.max_tokens or 300,
        allow_write=args.write,
    )
|
|
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--model", choices=["qwen35", "gemma4"], default="gemma4") |
| parser.add_argument("--nodes", required=True) |
| parser.add_argument("--max-iterations", type=int, default=6) |
| parser.add_argument("--max-tokens", type=int, default=300) |
| parser.add_argument("--write", action="store_true") |
| args = parser.parse_args() |
| main(args) |
|
|