ruslanmv committed
Commit fd95484 · 1 Parent(s): 4740c16

First working version of chat

app/core/config.py CHANGED
@@ -1,14 +1,14 @@
 from __future__ import annotations
-import os
-import yaml
+import os, yaml
 from pydantic import BaseModel, AnyHttpUrl
 from typing import Optional
 
 class ModelCfg(BaseModel):
-    name: str = "meta-llama/Meta-Llama-3-8B-Instruct"
+    name: str = "HuggingFaceH4/zephyr-7b-beta"
     fallback: str = "mistralai/Mistral-7B-Instruct-v0.2"
     max_new_tokens: int = 256
     temperature: float = 0.2
+    provider: Optional[str] = None  # NEW
 
 class LimitsCfg(BaseModel):
     rate_per_min: int = 60
@@ -30,26 +30,24 @@ class Settings(BaseModel):
     rag: RagCfg = RagCfg()
     matrixhub: MatrixHubCfg = MatrixHubCfg()
     security: SecurityCfg = SecurityCfg()
+    chat_backend: str = "router"  # NEW (reserved)
+    chat_stream: bool = True  # NEW
 
     @staticmethod
     def load() -> Settings:
-        """Loads settings from YAML and overrides with environment variables."""
         path = os.getenv("SETTINGS_FILE", "configs/settings.yaml")
         data = {}
         if os.path.exists(path):
             with open(path, "r", encoding="utf-8") as f:
                 data = yaml.safe_load(f) or {}
-
         settings = Settings.model_validate(data)
 
-        # Environment variable overrides
-        if "MODEL_NAME" in os.environ:
-            settings.model.name = os.environ["MODEL_NAME"]
-        if "INDEX_DATASET" in os.environ:
-            settings.rag.index_dataset = os.environ["INDEX_DATASET"]
-        if "RATE_LIMITS" in os.environ:
-            settings.limits.rate_per_min = int(os.environ["RATE_LIMITS"])
-        if "ADMIN_TOKEN" in os.environ:
-            settings.security.admin_token = os.environ["ADMIN_TOKEN"]
-
+        # Env overrides
+        if "MODEL_NAME" in os.environ: settings.model.name = os.environ["MODEL_NAME"]
+        if "MODEL_FALLBACK" in os.environ: settings.model.fallback = os.environ["MODEL_FALLBACK"]
+        if "MODEL_PROVIDER" in os.environ: settings.model.provider = os.environ["MODEL_PROVIDER"]
+        if "ADMIN_TOKEN" in os.environ: settings.security.admin_token = os.environ["ADMIN_TOKEN"]
+        if "RATE_LIMITS" in os.environ: settings.limits.rate_per_min = int(os.environ["RATE_LIMITS"])
+        if "HF_CHAT_BACKEND" in os.environ: settings.chat_backend = os.environ["HF_CHAT_BACKEND"].strip().lower()
+        if "CHAT_STREAM" in os.environ: settings.chat_stream = os.environ["CHAT_STREAM"].lower() in ("1", "true", "yes", "on")
         return settings
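
For reference, the override logic above can be exercised like this (a minimal sketch; the `app.core.config` import path is assumed from the repo layout):

import os
from app.core.config import Settings  # assumed package path

os.environ["MODEL_NAME"] = "HuggingFaceH4/zephyr-7b-beta"
os.environ["CHAT_STREAM"] = "false"

settings = Settings.load()   # reads configs/settings.yaml, then applies env overrides
print(settings.model.name)   # HuggingFaceH4/zephyr-7b-beta
print(settings.chat_stream)  # False ("false" is not in the truthy set "1"/"true"/"yes"/"on")
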
app/core/inference/client.py CHANGED
@@ -1,94 +1,144 @@
-import os
-import logging
-import httpx
-from typing import Optional, Any, Union
-from tenacity import retry, stop_after_attempt, wait_exponential
-
-logger = logging.getLogger(__name__)
-
-class HFClient:
-    def __init__(self, model: str, fallback: Optional[str] = None, timeout: int = 20):
-        self.model = model
-        self.fallback = fallback
-        self.timeout = timeout
-
-        token = os.getenv("HF_TOKEN")
-        if not token:
-            raise ValueError("HF_TOKEN environment variable is not set. Put it in .env or export it before starting.")
-
-        self.headers = {
-            "Authorization": f"Bearer {token}",
-            "Accept": "application/json",
-        }
-        self.api_base = "https://api-inference.huggingface.co/models"
-
-    async def _post(self, model: str, payload: dict) -> Any:
-        url = f"{self.api_base}/{model}"
-        # wait_for_model=true is helpful if the container is cold
-        params = {"wait_for_model": "true"}
-        async with httpx.AsyncClient(timeout=self.timeout) as client:
-            r = await client.post(url, headers=self.headers, json=payload, params=params)
-            r.raise_for_status()
-            return r.json()
-
-    @staticmethod
-    def _extract_text(data: Union[dict, list, str]) -> str:
-        # HF can return list[{"generated_text": "..."}] or {"generated_text": "..."} or str
-        if isinstance(data, list) and data and isinstance(data[0], dict) and "generated_text" in data[0]:
-            return str(data[0]["generated_text"])
-        if isinstance(data, dict) and "generated_text" in data:
-            return str(data["generated_text"])
-        if isinstance(data, str):
-            return data
-        # Some serverless returns {"error": "..."} with 200—handle gently
-        if isinstance(data, dict) and "error" in data:
-            raise RuntimeError(f"Hugging Face error: {data['error']}")
-        raise RuntimeError(f"Unexpected HF response format: {data!r}")
-
-    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=8))
-    async def _generate_once(self, model: str, prompt: str, max_new_tokens: int, temperature: float) -> str:
-        payload = {
-            "inputs": prompt,
-            "parameters": {
-                "max_new_tokens": max(1, int(max_new_tokens)),
-                "temperature": float(max(temperature, 0.01)),
-                "return_full_text": False,
-            },
-        }
-        data = await self._post(model, payload)
-        return self._extract_text(data)
-
-    async def generate(self, prompt: str, max_new_tokens: int, temperature: float) -> str:
-        # Try primary
-        try:
-            return await self._generate_once(self.model, prompt, max_new_tokens, temperature)
-        except httpx.HTTPStatusError as e:
-            code = e.response.status_code
-            body = e.response.text
-            logger.error("HTTP error from HF API for model %s: %s", self.model, body)
-            # If not authorized / not found / gated, try fallback if defined
-            if code in (401, 403, 404) and self.fallback and self.fallback != self.model:
-                logger.warning("Falling back to model %s due to %s", self.fallback, code)
-                try:
-                    return await self._generate_once(self.fallback, prompt, max_new_tokens, temperature)
-                except Exception:
-                    # re-raise original meaningful error below
-                    pass
-            # Give a readable hint for common cause with Llama
-            if code in (401, 403, 404) and "meta-llama" in self.model.lower():
-                raise PermissionError(
-                    "Hugging Face returned 404/403 for a gated model. "
-                    "Make sure your HF account accepted the model license and your HF_TOKEN has access. "
-                    f"Model={self.model}"
-                ) from e
-            raise
-        except Exception as e:
-            logger.error("Failed to call HF API for model %s: %s", self.model, e)
-            # Try fallback for transient or parsing errors
-            if self.fallback and self.fallback != self.model:
-                try:
-                    logger.warning("Falling back to model %s due to generic failure", self.fallback)
-                    return await self._generate_once(self.fallback, prompt, max_new_tokens, temperature)
-                except Exception:
-                    pass
-            raise
+import os, json, time, logging
+from typing import Dict, List, Optional, Iterator, Tuple
+
+import requests
+
+logger = logging.getLogger(__name__)
+
+ROUTER_URL = "https://router.huggingface.co/v1/chat/completions"
+
+def _require_token() -> str:
+    tok = os.getenv("HF_TOKEN")
+    if not tok:
+        raise ValueError("HF_TOKEN is not set. Put it in .env or export it before starting.")
+    return tok
+
+def _model_with_provider(model: str, provider: Optional[str]) -> str:
+    if provider and ":" not in model:
+        return f"{model}:{provider}"
+    return model
+
+def _mk_messages(system_prompt: Optional[str], user_text: str) -> List[Dict[str, str]]:
+    msgs: List[Dict[str, str]] = []
+    if system_prompt:
+        msgs.append({"role": "system", "content": system_prompt})
+    msgs.append({"role": "user", "content": user_text})
+    return msgs
+
+def _timeout_tuple(connect: float = 10.0, read: float = 60.0) -> Tuple[float, float]:
+    # requests timeout is (connect, read)
+    return (connect, read)
+
+class RouterRequestsClient:
+    """
+    Simple requests-only client for HF Router Chat Completions.
+    Supports non-streaming (returns str) and streaming (yields token strings).
+    """
+    def __init__(self, model: str, fallback: Optional[str] = None, provider: Optional[str] = None,
+                 max_retries: int = 2, connect_timeout: float = 10.0, read_timeout: float = 60.0):
+        self.model = model
+        self.fallback = fallback if fallback != model else None
+        self.provider = provider
+        self.headers = {"Authorization": f"Bearer {_require_token()}"}
+        self.max_retries = max(0, int(max_retries))
+        self.timeout = _timeout_tuple(connect_timeout, read_timeout)
+
+    # -------- Non-stream (single text) --------
+    def chat_nonstream(self, system_prompt: Optional[str], user_text: str,
+                       max_tokens: int, temperature: float) -> str:
+        payload = {
+            "model": _model_with_provider(self.model, self.provider),
+            "messages": _mk_messages(system_prompt, user_text),
+            "temperature": float(temperature),
+            "max_tokens": int(max_tokens),
+            "stream": False,
+        }
+        text, ok = self._try_once(payload)
+        if ok:
+            return text
+
+        # fallback (if configured)
+        if self.fallback:
+            payload["model"] = _model_with_provider(self.fallback, self.provider)
+            text, ok = self._try_once(payload)
+            if ok:
+                return text
+
+        raise RuntimeError(f"Chat non-stream failed: model={self.model} fallback={self.fallback}")
+
+    def _try_once(self, payload: dict) -> Tuple[str, bool]:
+        last_err = None
+        for attempt in range(self.max_retries + 1):
+            try:
+                r = requests.post(ROUTER_URL, headers=self.headers, json=payload, timeout=self.timeout)
+                if r.status_code >= 400:
+                    logger.error("Router error %s: %s", r.status_code, r.text)
+                    last_err = RuntimeError(f"{r.status_code}: {r.text}")
+                    # do not hard-spin; brief pause
+                    time.sleep(min(1.5 * (attempt + 1), 3.0))
+                    continue
+                data = r.json()
+                return data["choices"][0]["message"]["content"], True
+            except Exception as e:
+                logger.error("Router request failure: %s", e)
+                last_err = e
+                time.sleep(min(1.5 * (attempt + 1), 3.0))
+        if last_err:
+            logger.error("Router exhausted retries: %s", last_err)
+        return "", False
+
+    # -------- Streaming (yield token deltas) --------
+    def chat_stream(self, system_prompt: Optional[str], user_text: str,
+                    max_tokens: int, temperature: float) -> Iterator[str]:
+        payload = {
+            "model": _model_with_provider(self.model, self.provider),
+            "messages": _mk_messages(system_prompt, user_text),
+            "temperature": float(temperature),
+            "max_tokens": int(max_tokens),
+            "stream": True,
+        }
+        # primary
+        ok = False
+        for token in self._stream_once(payload):
+            ok = True
+            yield token
+        if ok:
+            return
+        # fallback stream if primary produced nothing (or died immediately)
+        if self.fallback:
+            payload["model"] = _model_with_provider(self.fallback, self.provider)
+            for token in self._stream_once(payload):
+                yield token
+
+    def _stream_once(self, payload: dict) -> Iterator[str]:
+        try:
+            with requests.post(ROUTER_URL, headers=self.headers, json=payload, stream=True, timeout=self.timeout) as r:
+                if r.status_code >= 400:
+                    logger.error("Router stream error %s: %s", r.status_code, r.text)
+                    return
+                for line in r.iter_lines(decode_unicode=True):
+                    if not line:
+                        continue
+                    if not line.startswith("data:"):
+                        continue
+                    data = line[len("data:"):].strip()
+                    if data == "[DONE]":
+                        return
+                    try:
+                        obj = json.loads(data)
+                        # OpenAI-style: delta tokens
+                        delta = obj["choices"][0]["delta"].get("content", "")
+                        if delta:
+                            yield delta
+                    except Exception as e:
+                        logger.warning("Stream JSON parse issue: %s | line=%r", e, line)
+                        continue
+        except Exception as e:
+            logger.error("Stream request failure: %s", e)
+            return
+
+    # -------- Planning (non-stream) --------
+    def plan_nonstream(self, system_prompt: str, user_text: str,
+                       max_tokens: int, temperature: float) -> str:
+        """Use same chat/completions but always non-stream for planning."""
+        return self.chat_nonstream(system_prompt, user_text, max_tokens, temperature)
 
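A quick usage sketch for the new client (assumes HF_TOKEN is exported; model and provider values mirror configs/settings.yaml below):

from app.core.inference.client import RouterRequestsClient  # assumed package path

client = RouterRequestsClient(
    model="HuggingFaceH4/zephyr-7b-beta",
    fallback="microsoft/Phi-3-mini-4k-instruct",
    provider="featherless-ai",
)

# Non-streaming: a single string comes back (or RuntimeError after retries + fallback).
print(client.chat_nonstream("You are terse.", "Say hi.", max_tokens=32, temperature=0.2))

# Streaming: consume OpenAI-style delta tokens as they arrive.
for delta in client.chat_stream("You are terse.", "Count to three.", max_tokens=32, temperature=0.2):
    print(delta, end="", flush=True)
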
app/services/chat_service.py CHANGED
@@ -1,6 +1,6 @@
 from __future__ import annotations
 from ..core.config import Settings
-from ..core.inference.client import HFClient
+from ..core.inference.client import RouterRequestsClient
 
 SYSTEM_PROMPT = (
     "You are MATRIX-AI, a concise, helpful assistant for the Matrix EcoSystem. "
@@ -10,16 +10,27 @@ SYSTEM_PROMPT = (
 class ChatService:
     def __init__(self, settings: Settings):
         self.settings = settings
-        self.client = HFClient(
+        self.client = RouterRequestsClient(
             model=settings.model.name,
             fallback=settings.model.fallback,
+            provider=settings.model.provider,
+            max_retries=2,
+            connect_timeout=10.0,
+            read_timeout=60.0,
         )
 
     async def answer(self, query: str) -> str:
-        prompt = f"{SYSTEM_PROMPT}\n\nUser: {query}\nAssistant:"
-        text = await self.client.generate(
-            prompt=prompt,
-            max_new_tokens=self.settings.model.max_new_tokens,
+        # non-stream (compatible with current UI)
+        return self.client.chat_nonstream(
+            SYSTEM_PROMPT, query,
+            max_tokens=self.settings.model.max_new_tokens,
+            temperature=self.settings.model.temperature,
+        )
+
+    # Expose a generator for streaming endpoints
+    def stream_answer(self, query: str):
+        return self.client.chat_stream(
+            SYSTEM_PROMPT, query,
+            max_tokens=self.settings.model.max_new_tokens,
             temperature=self.settings.model.temperature,
         )
-        return (text or "").strip()
 
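`stream_answer` returns a plain token generator, so a streaming endpoint has to wrap it in a transport of its choice. A hypothetical FastAPI/SSE wiring (route name and imports are assumptions, not part of this commit):

from fastapi import APIRouter
from fastapi.responses import StreamingResponse

from app.core.config import Settings
from app.services.chat_service import ChatService

router = APIRouter()

@router.get("/v1/chat/stream")
def chat_stream_endpoint(q: str):
    service = ChatService(Settings.load())

    def sse():
        # Re-wrap raw token deltas as SSE "data:" lines for the browser.
        for delta in service.stream_answer(q):
            yield f"data: {delta}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(sse(), media_type="text/event-stream")
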
app/services/plan_service.py CHANGED
@@ -1,56 +1,195 @@
+from __future__ import annotations
+
+import asyncio
 import hashlib
 import json
 import logging
 from pathlib import Path
+from typing import Any, Dict, Optional
+
 from ..core.schema import PlanRequest, PlanResponse
 from ..core.config import Settings
-from ..core.inference.client import HFClient
 from ..core.redact import redact
+from ..core.inference.client import RouterRequestsClient
 
 logger = logging.getLogger(__name__)
-_PROMPT_TEMPLATE: str | None = None
 
+# ----------------------------
+# Prompts
+# ----------------------------
+SYSTEM_PLANNER = (
+    "You are MATRIX-AI Planner. Produce a short, safe JSON plan. "
+    "Bounded steps, minimal risk, and explain briefly."
+)
+
+_PROMPT_TEMPLATE_CACHE: Optional[str] = None
+
+
 def _get_prompt_template() -> str:
-    global _PROMPT_TEMPLATE
-    if _PROMPT_TEMPLATE is None:
-        try:
-            path = Path(__file__).parent.parent / "core/prompts/plan.txt"
-            _PROMPT_TEMPLATE = path.read_text(encoding="utf-8")
-        except FileNotFoundError:
-            logger.error("FATAL: core/prompts/plan.txt not found.")
-            _PROMPT_TEMPLATE = "Generate a JSON plan with keys: plan_id, steps, risk, explanation."
-    return _PROMPT_TEMPLATE
-
-def _create_final_prompt(req: PlanRequest) -> str:
+    """
+    Load core/prompts/plan.txt once (cached).
+    Fallback to a minimal instruction if missing.
+    """
+    global _PROMPT_TEMPLATE_CACHE
+    if _PROMPT_TEMPLATE_CACHE is not None:
+        return _PROMPT_TEMPLATE_CACHE
+
+    try:
+        path = Path(__file__).parent.parent / "core" / "prompts" / "plan.txt"
+        _PROMPT_TEMPLATE_CACHE = path.read_text(encoding="utf-8")
+    except FileNotFoundError:
+        logger.error("FATAL: core/prompts/plan.txt not found. Using fallback template.")
+        _PROMPT_TEMPLATE_CACHE = (
+            "Generate a JSON plan with keys: plan_id, steps, risk, explanation. "
+            "Keep steps short, safe, and auditable."
+        )
+    return _PROMPT_TEMPLATE_CACHE
+
+
+def _render_context(req: PlanRequest) -> str:
+    """
+    Render a compact context string from the request.
+    (Matches your earlier shape: app_id, symptoms, lkg, constraints.)
+    """
+    app_id = getattr(req.context, "app_id", None) or getattr(req.context, "entity_uid", "unknown")
+    symptoms = getattr(req.context, "symptoms", []) or []
+    lkg = getattr(req.context, "lkg", None) or getattr(req.context, "lkg_version", None) or "N/A"
+
+    max_steps = getattr(req.constraints, "max_steps", 3)
+    risk = getattr(req.constraints, "risk", "low")
+
+    return (
+        "Context:\n"
+        f"- app_id: {app_id}\n"
+        f"- symptoms: {', '.join(symptoms) if symptoms else 'none'}\n"
+        f"- lkg_version: {lkg}\n"
+        f"- constraints: max_steps={max_steps}, risk={risk}"
+    )
+
+
+def _build_prompt(req: PlanRequest) -> str:
+    """
+    Compose final prompt with system guidance + template + redacted context.
+    """
     template = _get_prompt_template()
-    context_str = f"Context:\n- app_id: {req.context.app_id}\n- symptoms: {', '.join(req.context.symptoms)}\n- lkg_version: {req.context.lkg or 'N/A'}\n- constraints: max_steps={req.constraints.max_steps}, risk={req.constraints.risk}"
+    context_str = _render_context(req)
     safe_context = redact(context_str)
-    return f"{template}\n\n{safe_context}\n\nJSON Response:"
+
+    # You can tweak ordering if desired; this is clear and stable.
+    return f"{SYSTEM_PLANNER}\n\n{template}\n\n{safe_context}\n\nJSON Response:"
 
-def _parse_llm_output(raw_output: str, context_str: str) -> dict:
+
+# ----------------------------
+# Output parsing
+# ----------------------------
+def _extract_json_block(text: str) -> Dict[str, Any]:
+    """
+    Try hard to recover a JSON object from LLM text.
+    Supports ```json fences and "first { ... last }".
+    Raises ValueError if no JSON object can be extracted.
+    """
+    s = text.strip()
+
+    # Fenced block: ```json ... ```
+    if "```" in s:
+        fence_start = s.find("```")
+        lang_tag = s.find("\n", fence_start + 3)
+        if lang_tag != -1:
+            fence_close = s.find("```", lang_tag + 1)
+            if fence_close != -1:
+                fenced = s[lang_tag + 1 : fence_close].strip()
+                return json.loads(fenced)
+
+    # Plain: first "{" to last "}"
+    first = s.find("{")
+    last = s.rfind("}")
+    if first != -1 and last != -1 and last > first:
+        candidate = s[first : last + 1]
+        return json.loads(candidate)
+
+    raise ValueError("No valid JSON object found in output.")
+
+
+def _safe_parse_or_fallback(raw_output: str, context_for_id: str) -> Dict[str, Any]:
+    """
+    Parse the model output into a dict, or return a safe fallback plan.
+    """
     try:
-        start = raw_output.find('{')
-        end = raw_output.rfind('}')
-        if start != -1 and end != -1 and end > start:
-            json_str = raw_output[start:end+1]
-            return json.loads(json_str)
-        raise ValueError("No valid JSON object found in output.")
-    except (json.JSONDecodeError, ValueError) as e:
-        logger.warning(f"LLM output parsing failed: {e}. Applying safe fallback plan.")
+        obj = _extract_json_block(raw_output)
+        if not isinstance(obj, dict):
+            raise ValueError("Top-level JSON is not an object.")
+
+        # Minimal normalization: ensure keys exist
+        if "plan_id" not in obj or not obj["plan_id"]:
+            obj["plan_id"] = hashlib.md5(context_for_id.encode()).hexdigest()[:12]
+        if "steps" not in obj or not obj["steps"]:
+            obj["steps"] = [
+                "Pin to the last-known-good (LKG) version and re-run health probes."
+            ]
+        if "risk" not in obj or not obj["risk"]:
+            obj["risk"] = "low"
+        if "explanation" not in obj or not obj["explanation"]:
+            obj["explanation"] = "Autofilled explanation."
+
+        return obj
+
+    except Exception as e:
+        logger.warning("LLM output parsing failed: %s. Applying fallback plan.", e)
         return {
-            "plan_id": hashlib.md5(context_str.encode()).hexdigest()[:12],
-            "steps": ["Pin to the last-known-good (LKG) version and re-run health probes."],
+            "plan_id": hashlib.md5(context_for_id.encode()).hexdigest()[:12],
+            "steps": [
+                "Pin to the last-known-good (LKG) version and re-run health probes."
+            ],
             "risk": "low",
-            "explanation": "Fallback plan: A safe default was applied due to a model output parsing error."
+            "explanation": (
+                "Fallback plan: A safe default was applied due to a model output parsing error."
+            ),
         }
 
+
+# ----------------------------
+# Service (requests-only, non-stream)
+# ----------------------------
+class PlanService:
+    """
+    Planner uses HF Router (requests-only). Always non-stream for plan generation.
+    """
+
+    def __init__(self, settings: Settings):
+        self.settings = settings
+        self.client = RouterRequestsClient(
+            model=settings.model.name,
+            fallback=settings.model.fallback,
+            provider=settings.model.provider,
+            max_retries=2,
+            connect_timeout=10.0,
+            read_timeout=60.0,
+        )
+
+    async def generate(self, req: PlanRequest) -> PlanResponse:
+        """
+        Build prompt -> call Router (non-stream) -> robustly parse -> PlanResponse.
+        """
+        final_prompt = _build_prompt(req)
+        # run the blocking requests call in a worker thread to avoid blocking the event loop
+        raw_text = await asyncio.to_thread(
+            self.client.plan_nonstream,
+            SYSTEM_PLANNER,
+            final_prompt,
+            self.settings.model.max_new_tokens,
+            self.settings.model.temperature,
+        )
+        parsed = _safe_parse_or_fallback(raw_text, final_prompt)
+        return PlanResponse.model_validate(parsed)
+
+
+# ----------------------------
+# Back-compat function (keeps existing imports working)
+# ----------------------------
 async def generate_plan(req: PlanRequest, settings: Settings) -> PlanResponse:
-    final_prompt = _create_final_prompt(req)
-    client = HFClient(model=settings.model.name)
-    raw_response = await client.generate(
-        prompt=final_prompt,
-        max_new_tokens=settings.model.max_new_tokens,
-        temperature=settings.model.temperature,
-    )
-    parsed_data = _parse_llm_output(raw_response, final_prompt)
-    return PlanResponse.model_validate(parsed_data)
+    """
+    Backward-compatible entry point:
+    previous code called services.plan.generate_plan(...)
+    """
+    service = PlanService(settings)
+    return await service.generate(req)
 
 
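The parsing path is easiest to see on two representative model outputs (a behavior sketch using the helper defined above; the import path is assumed from the repo layout):

from app.services.plan_service import _extract_json_block

fenced = 'Here is the plan:\n```json\n{"plan_id": "abc", "steps": ["s1"], "risk": "low", "explanation": "ok"}\n```'
plain = 'Sure! {"plan_id": "abc123", "steps": ["s1"], "risk": "low", "explanation": "ok"} Anything else?'

print(_extract_json_block(fenced)["plan_id"])  # abc    (recovered from the ```json fence)
print(_extract_json_block(plain)["plan_id"])   # abc123 (recovered from first "{" to last "}")
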
configs/settings.yaml CHANGED
@@ -1,21 +1,24 @@
 model:
-  name: "HuggingFaceH4/zephyr-7b-beta" # good balance of speed and capability
-  #name: "mistralai/Mistral-7B-Instruct-v0.2" # capable, open, but large
-  #fallback: "HuggingFaceH4/zephyr-7b-beta" # smaller, faster, but less capable
-  fallback: "microsoft/Phi-3-mini-4k-instruct" # smaller, faster, but less capable
+  name: "HuggingFaceH4/zephyr-7b-beta"
+  fallback: "microsoft/Phi-3-mini-4k-instruct"
+  provider: "featherless-ai" # NEW: makes "model:provider" for Router
   max_new_tokens: 256
   temperature: 0.2
 
+# Chat backend + mode (requests → Router only)
+chat_backend: "router" # reserved (future multi-backend)
+chat_stream: true      # default streaming behavior for /v1/chat/stream
+
 limits:
   rate_per_min: 60
   cache_size: 256
 
 rag:
-  index_dataset: "" # e.g., "your-username/matrix-ai-index"
+  index_dataset: ""
   top_k: 4
 
 matrixhub:
   base_url: "https://api.matrixhub.io"
 
 security:
-  admin_token: "" # Should be set via env var
+  admin_token: ""
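
The new `provider` key is what turns the model id into the "model:provider" form the Router expects; a quick check with the helper from app/core/inference/client.py:

from app.core.inference.client import _model_with_provider

print(_model_with_provider("HuggingFaceH4/zephyr-7b-beta", "featherless-ai"))
# HuggingFaceH4/zephyr-7b-beta:featherless-ai

print(_model_with_provider("HuggingFaceH4/zephyr-7b-beta:featherless-ai", "featherless-ai"))
# unchanged: the suffix is never added twice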