glitchfilter committed on
Commit
c6c7459
·
verified ·
1 Parent(s): 5769e72

Fix model inference + heuristic fallback

Browse files
Files changed (1) hide show
  1. server/app.py +117 -69
server/app.py CHANGED
@@ -125,51 +125,58 @@ AVAILABLE_MODELS = {
125
  "trl": {"id": "glitchfilter/methanol-apc-grpo-qwen2.5-3b", "label": "TRL GRPO (Qwen2.5-3B)"},
126
  }
127
 
128
- SYSTEM_PROMPT = (
129
  "You control a methanol synthesis reactor. Output a JSON object with these fields: "
130
  "feed_rate_h2 (0-10 mol/s), feed_rate_co (0-5 mol/s), cooling_water_flow (0-100 L/min), "
131
  "compressor_power (0-100 kW). The reactor is exothermic: 240-260C is optimal, >300C = shutdown. "
132
  "Maintain H2/CO ratio near 2.0. Revenue is $0.74/kg methanol."
133
  )
134
 
 
 
 
 
 
 
 
 
135
  def _load_model(model_key):
136
  """Lazy-load a LoRA adapter. Cached after first load."""
137
  if model_key in _loaded_models:
138
  return _loaded_models[model_key]
139
- try:
140
- import torch
141
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
142
- from peft import PeftModel
143
-
144
- info = AVAILABLE_MODELS[model_key]
145
- adapter_id = info["id"]
146
-
147
- # Determine base model from adapter_config
148
- from huggingface_hub import hf_hub_download
149
- import json
150
- cfg_path = hf_hub_download(adapter_id, "adapter_config.json")
151
- with open(cfg_path) as f:
152
- adapter_cfg = json.load(f)
153
- base_model_id = adapter_cfg.get("base_model_name_or_path", "Qwen/Qwen2.5-3B-Instruct")
154
-
155
- bnb = BitsAndBytesConfig(
156
- load_in_4bit=True, bnb_4bit_quant_type="nf4",
157
- bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True)
158
- base = AutoModelForCausalLM.from_pretrained(
159
- base_model_id, quantization_config=bnb, device_map="auto", trust_remote_code=True)
160
- model = PeftModel.from_pretrained(base, adapter_id)
161
- model.eval()
162
- tokenizer = AutoTokenizer.from_pretrained(adapter_id, trust_remote_code=True)
163
- if tokenizer.pad_token is None:
164
- tokenizer.pad_token = tokenizer.eos_token
165
- _loaded_models[model_key] = (model, tokenizer)
166
- return (model, tokenizer)
167
- except Exception as e:
168
- raise RuntimeError(f"Failed to load model {model_key}: {e}")
169
 
170
 
171
  def _obs_to_text(obs_dict):
172
- """Convert observation dict to compact sensor text for the model prompt."""
173
  parts = []
174
  for k in ["temperature", "pressure", "feed_rate_h2", "feed_rate_co", "h2_co_ratio",
175
  "cooling_water_flow", "catalyst_health", "reaction_rate", "methanol_produced",
@@ -183,57 +190,98 @@ def _obs_to_text(obs_dict):
183
  return " ".join(parts)
184
 
185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  @app.get("/model/list")
187
  async def list_models():
188
- """Return available trained models."""
189
- return {"models": {k: v["label"] for k, v in AVAILABLE_MODELS.items()}}
190
 
 
191
 
192
  @app.post("/model/step")
193
- async def model_step(request):
194
- """Run one step using a trained model: load adapter, generate action, step env."""
195
  import json as _json
196
  body = await request.json()
197
  model_key = body.get("model", "trl")
198
  obs_dict = body.get("observation", {})
199
 
200
  if model_key not in AVAILABLE_MODELS:
201
- return {"error": f"Unknown model: {model_key}. Available: {list(AVAILABLE_MODELS.keys())}"}
202
-
203
- try:
204
- model, tokenizer = _load_model(model_key)
205
- except Exception as e:
206
- return {"error": f"Model load failed: {str(e)[:200]}"}
207
-
208
- # Build prompt
209
- sensor_text = _obs_to_text(obs_dict)
210
- messages = [
211
- {"role": "system", "content": SYSTEM_PROMPT},
212
- {"role": "user", "content": f"Sensors:\n{sensor_text}\n\nAction JSON:"},
213
- ]
214
- prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
215
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
216
-
217
- import torch
218
- with torch.no_grad():
219
- output = model.generate(
220
- **inputs, max_new_tokens=150, temperature=0.3,
221
- do_sample=True, pad_token_id=tokenizer.eos_token_id)
222
- response = tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
223
-
224
- # Parse action from model response
225
- try:
226
- text = response.strip()
227
- start, end = text.find("{"), text.rfind("}") + 1
228
- action_dict = _json.loads(text[start:end])
229
- except Exception:
230
- action_dict = {"feed_rate_h2": 3.0, "feed_rate_co": 1.5,
231
- "cooling_water_flow": 60.0, "compressor_power": 50.0}
232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  return {
234
  "action": action_dict,
235
- "raw_response": response[:300],
236
- "model": AVAILABLE_MODELS[model_key]["label"],
 
237
  }
238
 
239
 
 
125
  "trl": {"id": "glitchfilter/methanol-apc-grpo-qwen2.5-3b", "label": "TRL GRPO (Qwen2.5-3B)"},
126
  }
127
 
128
# System prompt sent with every model request.  It names the four action
# fields with their allowed ranges and states the operating targets; the
# sentences are joined with single spaces into one line of text.
_MODEL_SYSTEM_PROMPT = " ".join([
    "You control a methanol synthesis reactor.",
    "Output a JSON object with these fields: feed_rate_h2 (0-10 mol/s), "
    "feed_rate_co (0-5 mol/s), cooling_water_flow (0-100 L/min), "
    "compressor_power (0-100 kW).",
    "The reactor is exothermic: 240-260C is optimal, >300C = shutdown.",
    "Maintain H2/CO ratio near 2.0.",
    "Revenue is $0.74/kg methanol.",
])
134
 
135
# Probe once, at import time, whether CUDA inference is possible.  The flag
# starts as False so a host without torch (or without a GPU) still boots and
# can serve the heuristic fallback.
_GPU_AVAILABLE = False
try:
    import torch as _torch
    _GPU_AVAILABLE = bool(_torch.cuda.is_available())
except (ImportError, OSError):
    # ImportError: torch is not installed.
    # OSError: torch is installed but one of its native libraries could not be
    # loaded.  Both simply mean "no GPU inference here", so neither may abort
    # server start-up.
    pass
141
+
142
+
143
def _load_model(model_key):
    """Return the ``(model, tokenizer)`` pair for *model_key*, loading it on first use.

    The pair is stored in ``_loaded_models`` so later calls with the same key
    return the cached objects without touching the hub again.

    Args:
        model_key: Key into ``AVAILABLE_MODELS``; that entry's ``"id"`` is the
            adapter repository to load.

    Returns:
        A tuple ``(model, tokenizer)``: the PEFT-wrapped model in eval mode and
        the tokenizer loaded from the adapter repository.

    Raises:
        RuntimeError: If ``_GPU_AVAILABLE`` is false.
        KeyError: If *model_key* is not in ``AVAILABLE_MODELS``.
    """
    if model_key in _loaded_models:
        return _loaded_models[model_key]
    if not _GPU_AVAILABLE:
        raise RuntimeError("No GPU available. Use pre-recorded mode or HF Inference API.")

    # Heavy dependencies are imported lazily so the module itself imports
    # without them.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from peft import PeftModel

    adapter_id = AVAILABLE_MODELS[model_key]["id"]

    from huggingface_hub import hf_hub_download
    import json

    # The adapter's own config names the base model it was trained on; fall
    # back to the Qwen 3B instruct checkpoint when the field is absent.
    with open(hf_hub_download(adapter_id, "adapter_config.json")) as cfg_file:
        base_model_id = json.load(cfg_file).get(
            "base_model_name_or_path", "Qwen/Qwen2.5-3B-Instruct")

    # 4-bit NF4 quantisation with double quantisation and fp16 compute.
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        quantization_config=quant_cfg,
        device_map="auto",
        trust_remote_code=True,
    )
    peft_model = PeftModel.from_pretrained(base_model, adapter_id)
    peft_model.eval()

    tok = AutoTokenizer.from_pretrained(adapter_id, trust_remote_code=True)
    if tok.pad_token is None:
        # No dedicated pad token: reuse EOS.
        tok.pad_token = tok.eos_token

    loaded = (peft_model, tok)
    _loaded_models[model_key] = loaded
    return loaded
177
 
178
 
179
  def _obs_to_text(obs_dict):
 
180
  parts = []
181
  for k in ["temperature", "pressure", "feed_rate_h2", "feed_rate_co", "h2_co_ratio",
182
  "cooling_water_flow", "catalyst_health", "reaction_rate", "methanol_produced",
 
190
  return " ".join(parts)
191
 
192
 
193
# Rule-based stand-in for the trained model, used when inference is unavailable.
def _obs_float(value, default):
    """Return *value* converted to float, or *default* if it cannot be converted."""
    try:
        return float(value)
    except (TypeError, ValueError):
        # None (JSON null), non-numeric strings, nested containers, ...
        return default


def _heuristic_action(obs_dict):
    """Derive an action from an observation with fixed rules (no model needed).

    Starts from a nominal set-point (h2=5.0, co=2.5, cooling=50.0,
    compressor=65.0) and adjusts it:

    * temperature above 270: cut both feeds (floors 2.0 / 1.0) and raise
      cooling (cap 100.0), proportionally to the excess;
    * temperature below 240: raise both feeds (caps 8.0 / 4.0) and lower
      cooling (floor 10.0), proportionally to the shortfall;
    * catalyst_health below 0.6: scale both feeds by 0.8.

    Compressor power is never adjusted.

    Args:
        obs_dict: Mapping of sensor readings.  Only ``"temperature"``
            (default 250) and ``"catalyst_health"`` (default 1.0) are read.
            A missing, ``None`` or non-numeric reading falls back to its
            default, and a non-mapping argument is treated as an empty
            observation, so this fallback never raises on bad sensor data.

    Returns:
        dict with keys ``feed_rate_h2``, ``feed_rate_co`` (rounded to 2
        decimals), ``cooling_water_flow`` and ``compressor_power`` (rounded
        to 1 decimal).
    """
    readings = obs_dict if hasattr(obs_dict, "get") else {}
    T = _obs_float(readings.get("temperature", 250), 250.0)
    cat = _obs_float(readings.get("catalyst_health", 1.0), 1.0)

    h2 = 5.0
    co = 2.5
    cool = 50.0
    comp = 65.0

    if T > 270:
        excess = T - 270
        h2 = max(2.0, h2 - excess * 0.3)
        co = max(1.0, co - excess * 0.15)
        cool = min(100.0, cool + excess * 3.0)
    elif T < 240:
        shortfall = 240 - T
        h2 = min(8.0, h2 + shortfall * 0.2)
        co = min(4.0, co + shortfall * 0.1)
        cool = max(10.0, cool - shortfall * 2.0)

    if cat < 0.6:
        h2 *= 0.8
        co *= 0.8

    return {
        "feed_rate_h2": round(h2, 2),
        "feed_rate_co": round(co, 2),
        "cooling_water_flow": round(cool, 1),
        "compressor_power": round(comp, 1),
    }
224
+
225
+
226
@app.get("/model/list")
async def list_models():
    """Report the selectable models and whether a GPU was detected.

    Returns:
        ``{"models": {key: label, ...}, "gpu": bool}`` where the labels come
        from ``AVAILABLE_MODELS`` and ``gpu`` mirrors ``_GPU_AVAILABLE``.
    """
    labels = {key: spec["label"] for key, spec in AVAILABLE_MODELS.items()}
    return {"models": labels, "gpu": _GPU_AVAILABLE}
229
+
230
 
231
+ from starlette.requests import Request as _Request
232
 
233
@app.post("/model/step")
async def model_step(request: _Request):
    """Choose the next action for the posted observation.

    Request body (JSON object):
        ``model``: key into ``AVAILABLE_MODELS`` (default ``"trl"``).
        ``observation``: dict of sensor readings (default ``{}``; anything
        that is not a dict is treated as empty).

    Behaviour:
        * If ``_GPU_AVAILABLE`` is true, the adapter is loaded via
          ``_load_model`` and asked for an action.  The first ``{`` to the
          last ``}`` of its reply is parsed as JSON; if that fails, the
          heuristic action is substituted.
        * If no GPU is available, or loading/generation raises, the rule-based
          ``_heuristic_action`` is used instead.

    Returns:
        On success a dict with ``action``, ``raw_response``, ``model``,
        ``mode`` (``"gpu_inference"`` or ``"heuristic_fallback"``) and
        ``action_source`` (``"model"`` or ``"heuristic"``) so callers can tell
        when the heuristic replaced an unparseable model reply.
        On bad input a dict with a single ``error`` key.
    """
    import json as _json

    # A malformed or non-object body is reported in the same {"error": ...}
    # shape as an unknown model, instead of surfacing as an unhandled error.
    try:
        body = await request.json()
    except ValueError:  # JSONDecodeError and UnicodeDecodeError both derive from it
        body = None
    if not isinstance(body, dict):
        return {"error": "Request body must be a JSON object"}

    model_key = body.get("model", "trl")
    obs_dict = body.get("observation", {})
    if not isinstance(obs_dict, dict):
        # e.g. "observation": null -- the helpers below call .get() on it.
        obs_dict = {}

    # The isinstance check also keeps unhashable values (lists, objects) from
    # raising inside the membership test.
    if not isinstance(model_key, str) or model_key not in AVAILABLE_MODELS:
        return {"error": f"Unknown model: {model_key}"}

    fallback_reason = "no GPU"

    # Try GPU inference first
    if _GPU_AVAILABLE:
        try:
            model, tokenizer = _load_model(model_key)
            sensor_text = _obs_to_text(obs_dict)
            messages = [
                {"role": "system", "content": _MODEL_SYSTEM_PROMPT},
                {"role": "user", "content": f"Sensors:\n{sensor_text}\n\nAction JSON:"},
            ]
            prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True)
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

            import torch
            # NOTE(review): generate() is a synchronous call, so the event loop
            # is blocked for its duration; consider a worker thread if this
            # endpoint must serve concurrent requests.
            with torch.no_grad():
                output = model.generate(
                    **inputs, max_new_tokens=150, temperature=0.3,
                    do_sample=True, pad_token_id=tokenizer.eos_token_id)
            # Decode only the newly generated tokens, not the echoed prompt.
            prompt_len = inputs["input_ids"].shape[1]
            response = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
        except Exception as exc:
            # Boundary handler: any load/generation failure degrades to the
            # heuristic rather than failing the request.
            _env_log.warning("GPU inference failed, falling back to heuristic: %s", exc)
            fallback_reason = "GPU inference failed"
        else:
            action_source = "model"
            try:
                text = response.strip()
                start, end = text.find("{"), text.rfind("}") + 1
                action_dict = _json.loads(text[start:end])
            except ValueError:
                # The reply contained no parseable JSON object.
                action_dict = _heuristic_action(obs_dict)
                action_source = "heuristic"

            return {
                "action": action_dict,
                "raw_response": response[:300],
                "model": AVAILABLE_MODELS[model_key]["label"],
                "mode": "gpu_inference",
                "action_source": action_source,
            }

    # Fallback: rule-based heuristic (works everywhere, no GPU needed)
    action_dict = _heuristic_action(obs_dict)
    return {
        "action": action_dict,
        "raw_response": f"heuristic fallback ({fallback_reason})",
        "model": AVAILABLE_MODELS[model_key]["label"] + " (heuristic)",
        "mode": "heuristic_fallback",
        "action_source": "heuristic",
    }
286
 
287