ashu1069 committed on
Commit
e8956ff
·
1 Parent(s): b6c01d1

sync: matter cache + breaker, runtime threading lock

Browse files
Files changed (2) hide show
  1. matter/impact.py +67 -26
  2. transformers_runtime.py +33 -18
matter/impact.py CHANGED
@@ -31,9 +31,41 @@ from typing import Any
31
  from matter.passport import Environmental, Passport, Value
32
 
33
 
 
 
34
  CLIMATIQ_BASE_URL = "https://api.climatiq.io"
35
  CLIMATIQ_API_KEY_ENV = "CLIMATIQ_API_KEY"
36
- CLIMATIQ_TIMEOUT_S = 4.0 # tight — caller wants fast inference, not perfect data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
 
39
  @dataclass(frozen=True)
@@ -141,26 +173,40 @@ def estimate_climatiq(
141
  query = _CLIMATIQ_QUERIES.get((head_name, identity_class, next_best_action))
142
  if query is None:
143
  return None
 
 
 
144
  try:
145
  # Lazy import — httpx is in the package deps but optional at module load
146
  import httpx
147
  with httpx.Client(timeout=timeout) as client:
148
- # 1. Search for the matching factor
149
- search = client.get(
150
- f"{CLIMATIQ_BASE_URL}/data/v1/search",
151
- params={"query": query, "results_per_page": 1},
152
- headers={"Authorization": f"Bearer {api_key}"},
153
- )
154
- if search.status_code != 200:
155
- return None
156
- results = (search.json() or {}).get("results") or []
157
- if not results:
158
- return None
159
- top = results[0]
160
- activity_id = top.get("activity_id")
161
- data_version = top.get("data_version", "^21")
162
- if not activity_id:
163
- return None
 
 
 
 
 
 
 
 
 
 
 
164
 
165
  # 2. Estimate emissions for our weight
166
  estimate = client.post(
@@ -169,29 +215,24 @@ def estimate_climatiq(
169
  "emission_factor": {"activity_id": activity_id, "data_version": data_version},
170
  "parameters": {"weight": weight_kg, "weight_unit": "kg"},
171
  },
172
- headers={"Authorization": f"Bearer {api_key}"},
173
  )
174
  if estimate.status_code != 200:
 
175
  return None
176
  payload: dict[str, Any] = estimate.json() or {}
177
  co2e_kg = payload.get("co2e")
178
  if not isinstance(co2e_kg, (int, float)):
179
  return None
180
 
181
- # Climatiq returns *emissions caused* by the disposal pathway. Our
182
- # convention is *kg avoided vs. BAU landfill*. We sign-flip: a
183
- # recycling factor of, say, +0.8 kg CO2e/kg becomes -0.8 in our
184
- # terms (recycling itself emits, but it AVOIDS landfill emissions
185
- # which are separately measured). The static table already encodes
186
- # the avoided-vs-BAU framing; the Climatiq path here surfaces the
187
- # raw emissions of the routing — caller should treat as a positive
188
- # footprint number, not a savings.
189
  basis = (
190
  f"climatiq:{activity_id} (data_version {data_version}) "
191
  f"× {weight_kg:.3f} kg via search '{query}'"
192
  )
193
  return float(co2e_kg), basis
194
  except Exception:
 
195
  return None
196
 
197
 
 
31
  from matter.passport import Environmental, Passport, Value
32
 
33
 
34
+ import threading
35
+
36
  CLIMATIQ_BASE_URL = "https://api.climatiq.io"
37
  CLIMATIQ_API_KEY_ENV = "CLIMATIQ_API_KEY"
38
+ CLIMATIQ_TIMEOUT_S = 3.0 # tight — caller wants fast inference, not perfect data
39
+ CLIMATIQ_BREAKER_THRESHOLD = 3 # consecutive failures before we stop trying
40
+
41
+ # Cache: free-text query → (activity_id, data_version). Populated on the first
42
+ # successful /search; subsequent calls for the same query skip /search entirely.
43
+ _climatiq_cache: dict[str, tuple[str, str]] = {}
44
+ _climatiq_cache_lock = threading.Lock()
45
+
46
+ # Per-process circuit breaker. If three consecutive Climatiq calls fail,
47
+ # stop trying for the rest of this process's lifetime — fall back to static.
48
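+ # A successful call resets the failure count. Once tripped, no new calls are made, so only a call already in flight can close the breaker again.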
49
+ _climatiq_breaker = {"failures": 0, "tripped": False}
50
+ _climatiq_breaker_lock = threading.Lock()
51
+
52
+
53
def _breaker_record_failure() -> None:
    """Count one failed Climatiq call and trip the breaker at the threshold.

    The counter update and the threshold comparison both run while holding
    ``_climatiq_breaker_lock``, so they are applied as a single step.
    """
    with _climatiq_breaker_lock:
        failure_count = _climatiq_breaker["failures"] + 1
        _climatiq_breaker["failures"] = failure_count
        # Trip once the running count reaches CLIMATIQ_BREAKER_THRESHOLD.
        if failure_count >= CLIMATIQ_BREAKER_THRESHOLD:
            _climatiq_breaker["tripped"] = True
58
+
59
+
60
def _breaker_record_success() -> None:
    """Zero the failure count and mark the breaker as not tripped."""
    with _climatiq_breaker_lock:
        # Both fields are rewritten together while the lock is held.
        _climatiq_breaker.update(failures=0, tripped=False)
64
+
65
+
66
def _breaker_open() -> bool:
    """Return the breaker's current ``tripped`` flag, read under the lock."""
    with _climatiq_breaker_lock:
        is_tripped = _climatiq_breaker["tripped"]
    return is_tripped
69
 
70
 
71
  @dataclass(frozen=True)
 
173
  query = _CLIMATIQ_QUERIES.get((head_name, identity_class, next_best_action))
174
  if query is None:
175
  return None
176
+ if _breaker_open():
177
+ # Circuit breaker tripped — don't waste time/budget on more API calls.
178
+ return None
179
  try:
180
  # Lazy import — httpx is in the package deps but optional at module load
181
  import httpx
182
  with httpx.Client(timeout=timeout) as client:
183
+ headers = {"Authorization": f"Bearer {api_key}"}
184
+
185
+ # 1. Resolve activity_id — check in-process cache first, hit /search
186
+ # only on miss.
187
+ with _climatiq_cache_lock:
188
+ cached = _climatiq_cache.get(query)
189
+ if cached is not None:
190
+ activity_id, data_version = cached
191
+ else:
192
+ search = client.get(
193
+ f"{CLIMATIQ_BASE_URL}/data/v1/search",
194
+ params={"query": query, "results_per_page": 1},
195
+ headers=headers,
196
+ )
197
+ if search.status_code != 200:
198
+ _breaker_record_failure()
199
+ return None
200
+ results = (search.json() or {}).get("results") or []
201
+ if not results:
202
+ return None
203
+ top = results[0]
204
+ activity_id = top.get("activity_id")
205
+ data_version = top.get("data_version", "^21")
206
+ if not activity_id:
207
+ return None
208
+ with _climatiq_cache_lock:
209
+ _climatiq_cache[query] = (activity_id, data_version)
210
 
211
  # 2. Estimate emissions for our weight
212
  estimate = client.post(
 
215
  "emission_factor": {"activity_id": activity_id, "data_version": data_version},
216
  "parameters": {"weight": weight_kg, "weight_unit": "kg"},
217
  },
218
+ headers=headers,
219
  )
220
  if estimate.status_code != 200:
221
+ _breaker_record_failure()
222
  return None
223
  payload: dict[str, Any] = estimate.json() or {}
224
  co2e_kg = payload.get("co2e")
225
  if not isinstance(co2e_kg, (int, float)):
226
  return None
227
 
228
+ _breaker_record_success()
 
 
 
 
 
 
 
229
  basis = (
230
  f"climatiq:{activity_id} (data_version {data_version}) "
231
  f"× {weight_kg:.3f} kg via search '{query}'"
232
  )
233
  return float(co2e_kg), basis
234
  except Exception:
235
+ _breaker_record_failure()
236
  return None
237
 
238
 
transformers_runtime.py CHANGED
@@ -11,6 +11,7 @@ the MATTER_MODEL_ID Space secret.
11
  from __future__ import annotations
12
 
13
  import os
 
14
  from pathlib import Path
15
  from typing import Literal
16
 
@@ -53,28 +54,42 @@ class TransformersRuntime:
53
  self.max_new_tokens = max_new_tokens
54
  self._model = None
55
  self._processor = None
 
 
 
 
56
 
57
  def _ensure_loaded(self) -> None:
 
58
  if self._model is not None:
59
  return
60
- from transformers import AutoModelForImageTextToText, AutoProcessor
61
-
62
- dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
63
- device = "cuda" if torch.cuda.is_available() else "cpu"
64
-
65
- self._processor = AutoProcessor.from_pretrained(self.model_id)
66
- self._model = AutoModelForImageTextToText.from_pretrained(
67
- self.model_id,
68
- torch_dtype=dtype,
69
- device_map=device,
70
- )
71
- if self.lora_id:
72
- try:
73
- from peft import PeftModel
74
- self._model = PeftModel.from_pretrained(self._model, self.lora_id)
75
- except Exception as e:
76
- print(f"[TransformersRuntime] LoRA load failed ({self.lora_id}): {e}")
77
- self._model.eval()
 
 
 
 
 
 
 
 
 
78
 
79
  def infer(self, prompt: str, image: Path | None) -> str:
80
  return self._infer_gpu(prompt, str(image) if image is not None else None)
 
11
  from __future__ import annotations
12
 
13
  import os
14
+ import threading
15
  from pathlib import Path
16
  from typing import Literal
17
 
 
54
  self.max_new_tokens = max_new_tokens
55
  self._model = None
56
  self._processor = None
57
+ # Guards _ensure_loaded against concurrent first-call races. Two users
58
+ # hitting a cold Space simultaneously could both enter `from_pretrained`
59
+ # without this lock and double-allocate, OOM'ing CUDA.
60
+ self._load_lock = threading.Lock()
61
 
62
  def _ensure_loaded(self) -> None:
63
+ # Fast path: already loaded, no lock needed.
64
  if self._model is not None:
65
  return
66
+ with self._load_lock:
67
+ # Double-checked locking: another thread may have completed the
68
+ # load while we were waiting for the lock.
69
+ if self._model is not None:
70
+ return
71
+ from transformers import AutoModelForImageTextToText, AutoProcessor
72
+
73
+ dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
74
+ device = "cuda" if torch.cuda.is_available() else "cpu"
75
+
76
+ processor = AutoProcessor.from_pretrained(self.model_id)
77
+ model = AutoModelForImageTextToText.from_pretrained(
78
+ self.model_id,
79
+ torch_dtype=dtype,
80
+ device_map=device,
81
+ )
82
+ if self.lora_id:
83
+ try:
84
+ from peft import PeftModel
85
+ model = PeftModel.from_pretrained(model, self.lora_id)
86
+ except Exception as e:
87
+ print(f"[TransformersRuntime] LoRA load failed ({self.lora_id}): {e}")
88
+ model.eval()
89
+ # Publish atomically — readers without the lock should never see a
90
+ # half-initialized state.
91
+ self._processor = processor
92
+ self._model = model
93
 
94
  def infer(self, prompt: str, image: Path | None) -> str:
95
  return self._infer_gpu(prompt, str(image) if image is not None else None)