apsora committed · verified
Commit b7c755d · 1 Parent(s): ba3fc56

Update app.py

Files changed (1)
  1. app.py +1207 -49
app.py CHANGED
@@ -1,69 +1,1227 @@
  import gradio as gr
- from huggingface_hub import InferenceClient


- def respond(
-     message,
-     history: list[dict[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
-     hf_token: gr.OAuthToken,
- ):
      """
-     For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
      """
-     client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")

-     messages = [{"role": "system", "content": system_message}]

-     messages.extend(history)

-     messages.append({"role": "user", "content": message})

-     response = ""

-     for message in client.chat_completion(
          messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         choices = message.choices
-         token = ""
-         if len(choices) and choices[0].delta.content:
-             token = choices[0].delta.content

-         response += token
-         yield response


  """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- chatbot = gr.ChatInterface(
-     respond,
-     type="messages",
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
          ),
-     ],
- )

- with gr.Blocks() as demo:
-     with gr.Sidebar():
-         gr.LoginButton()
-     chatbot.render()


  if __name__ == "__main__":
+ import json
+ import re
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import numpy as np
+ import torch
+ from PIL import Image
  import gradio as gr
+ from transformers import AutoProcessor, AutoModelForVision2Seq
+ import spaces  # <-- needed for Stateless GPU / zeroGPU

+ # ---------------------------------------------------------------------
+ # Minimal GPU-decorated function so Stateless GPU doesn't error out
+ # ---------------------------------------------------------------------
+ @spaces.GPU
+ def gpu_ping() -> str:
+     """
+     Dummy GPU endpoint so Hugging Face Stateless GPU / zeroGPU
+     detects at least one @spaces.GPU function.

+     We don't actually use this in the app logic. It just keeps
+     the Space from throwing:
+     'No @spaces.GPU function detected during startup'.
      """
+     return "gpu_ready"
+
+
+ # ============================================================
+ # 0. Model + guidelines setup
+ # ============================================================
+
+ # NOTE: we keep everything on CPU here to avoid touching CUDA
+ # in the main process (required for Stateless GPU).
+ DEVICE = "cpu"
+ DTYPE = torch.float32
+
+ MODEL_NAME = "maryzhang/qwen3vl-guideline-lora-model"
+
+ print(f"Loading unified vision+text model {MODEL_NAME} on {DEVICE}", flush=True)
+
+ model_vlm = AutoModelForVision2Seq.from_pretrained(
+     MODEL_NAME,
+     dtype=DTYPE,
+     trust_remote_code=True,
+ )
+ model_vlm.to(DEVICE)
+ model_vlm.eval()
+
+ processor_vlm = AutoProcessor.from_pretrained(
+     MODEL_NAME,
+     trust_remote_code=True,
+ )
+
+ GUIDELINES_PATH = "guidelines_final.json"
+
+
+ def load_guidelines(path: str) -> List[Dict[str, Any]]:
+     """
+     Robust loader for guidelines_final.json.
+     Accepts:
+       - a big sequence of JSON objects (your current format)
+       - or a single list
+       - or {"guidelines": [...]}
+     Returns flat list of dicts that contain "guideline_id".
      """
+     with open(path, "r") as f:
+         raw = f.read()
+
+     raw = raw.strip()
+     if not raw:
+         raise ValueError("guidelines_final.json is empty.")
+
+     decoder = json.JSONDecoder()
+     pos = 0
+     length = len(raw)
+     objects: List[Any] = []
+
+     # collect all JSON fragments
+     while pos < length:
+         while pos < length and raw[pos].isspace():
+             pos += 1
+         if pos >= length:
+             break
+         try:
+             obj, end = decoder.raw_decode(raw, pos)
+         except json.JSONDecodeError:
+             pos += 1
+             continue
+         objects.append(obj)
+         pos = end
+
+     if not objects:
+         raise ValueError("No JSON fragments found in guidelines_final.json")
+
+     candidates: List[Any] = []
+     for obj in objects:
+         if isinstance(obj, list):
+             candidates.extend(obj)
+         elif isinstance(obj, dict) and isinstance(obj.get("guidelines"), list):
+             candidates.extend(obj["guidelines"])
+         elif isinstance(obj, dict):
+             candidates.append(obj)
+
+     guidelines: List[Dict[str, Any]] = []
+     for c in candidates:
+         if isinstance(c, dict) and "guideline_id" in c:
+             guidelines.append(c)
+
+     if not guidelines:
+         raise ValueError("Found JSON but no objects with 'guideline_id' field.")
+     return guidelines
+

+ ALL_GUIDELINES: List[Dict[str, Any]] = load_guidelines(GUIDELINES_PATH)
+ GUIDELINE_BY_ID: Dict[str, Dict[str, Any]] = {g["guideline_id"]: g for g in ALL_GUIDELINES}

+ print(f"Loaded {len(ALL_GUIDELINES)} guidelines", flush=True)


+ # ============================================================
+ # 1. Core LLM helpers (text-only + vision)
+ # ============================================================

+ def run_text_llm(system_prompt: str, user_prompt: str, max_new_tokens: int = 768) -> str:
+     """
+     Use Qwen3-VL (LoRA) in text-only mode.
+     """
+     messages = [
+         {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
+         {"role": "user", "content": [{"type": "text", "text": user_prompt}]},
+     ]
+     prompt_text = processor_vlm.apply_chat_template(
          messages,
+         tokenize=False,
+         add_generation_prompt=True,
+     )

+     inputs = processor_vlm(
+         text=prompt_text,
+         return_tensors="pt",
+     ).to(DEVICE)

+     with torch.no_grad():
+         output_ids = model_vlm.generate(
+             **inputs,
+             max_new_tokens=max_new_tokens,
+             temperature=0.0,
+             do_sample=False,
+         )

+     generated = processor_vlm.decode(
+         output_ids[0],
+         skip_special_tokens=True,
+     ).strip()
+     return generated
+
+
+ def vlm_generate_json_from_images(
+     prompt: str,
+     images: List[Image.Image],
+ ) -> Dict[str, Any]:
+     """
+     Call Qwen3-VL with images and ask it to return STRICT JSON.
+     """
+     if not images:
+         images = [Image.new("RGB", (64, 64), "white")]
+
+     content = [{"type": "image"} for _ in images]
+     content.append({"type": "text", "text": prompt})
+
+     messages = [{"role": "user", "content": content}]
+
+     prompt_text = processor_vlm.apply_chat_template(
+         messages,
+         tokenize=False,
+         add_generation_prompt=True,
+     )
+
+     inputs = processor_vlm(
+         text=prompt_text,
+         images=images,
+         return_tensors="pt",
+     ).to(DEVICE)
+
+     with torch.no_grad():
+         output_ids = model_vlm.generate(
+             **inputs,
+             max_new_tokens=512,
+             temperature=0.0,
+             do_sample=False,
+         )
+
+     generated = processor_vlm.decode(
+         output_ids[0],
+         skip_special_tokens=True,
+     ).strip()
+
+     m = re.search(r"\{.*\}", generated, re.DOTALL)
+     if m:
+         try:
+             return json.loads(m.group(0))
+         except Exception:
+             pass
+     return {"parse_error": True, "raw": generated}
+
+
+ # ============================================================
+ # 2. Feature extraction & guideline selection
+ # ============================================================
+
+ FEATURE_PROMPT = """
+ You are assisting with manufacturability and GD&T review.
+ Given these 1–3 CAD / drawing images, return a JSON object with:
+ {
+   "image_type": "cad_model" | "dimensioned_drawing" | "photo" | "other",
+   "has_gdt": bool,
+   "has_dimensions": bool,
+   "features": {
+     "holes": int,
+     "vertical_faces": bool,
+     "possible_draft": bool,
+     "ribs": int,
+     "fillets": bool,
+     "chamfers": bool,
+     "datum_symbols": ["A", "B"],
+     "gdt_frames_present": bool,
+     "text_dimensions_present": bool
+   },
+   "raw_notes": "short human-readable notes about what you see",
+   "generated_description": "one-sentence description of the part/drawing",
+   "suggested_guidelines": []
+ }
+ Rules:
+ - Infer only what is visible or strongly implied.
+ - Keep numbers rough (e.g., count of holes), not exact metrology.
+ - Only output valid JSON. No explanation outside the JSON.
+ - Do NOT hard-code any specific guideline IDs.
  """
+
+
+ def extract_visual_features(images: List[Image.Image]) -> Dict[str, Any]:
+     if not images:
+         return {
+             "image_type": "",
+             "has_gdt": False,
+             "has_dimensions": False,
+             "features": {
+                 "holes": 0,
+                 "vertical_faces": False,
+                 "possible_draft": False,
+                 "ribs": 0,
+                 "fillets": False,
+                 "chamfers": False,
+                 "datum_symbols": [],
+                 "gdt_frames_present": False,
+                 "text_dimensions_present": False,
+             },
+             "raw_notes": "",
+             "generated_description": "",
+             "suggested_guidelines": [],
+         }
+
+     vlm_json = vlm_generate_json_from_images(FEATURE_PROMPT, images)
+
+     return {
+         "image_type": vlm_json.get("image_type", ""),
+         "has_gdt": vlm_json.get("has_gdt", False),
+         "has_dimensions": vlm_json.get("has_dimensions", False),
+         "features": vlm_json.get("features", {}),
+         "raw_notes": vlm_json.get("raw_notes", ""),
+         "generated_description": vlm_json.get("generated_description", ""),
+         "suggested_guidelines": vlm_json.get("suggested_guidelines", []),
+     }
+
+
+ def rag_retrieve(query: str, top_k: int = 6) -> List[Dict[str, Any]]:
+     """
+     Tiny RAG over the 20 guidelines.
+     Now also includes pass_fail_logic in the searchable blob so the
+     evaluator can "see" the numeric rules.
+     """
+     q = (query or "").lower()
+     if not q.strip():
+         return []
+
+     scored = []
+     for g in ALL_GUIDELINES:
+         pfl = g.get("pass_fail_logic") or {}
+         pfl_text = " ".join(
+             f"{k}: {v}" for k, v in pfl.items()
+         )
+         blob = " ".join(
+             [
+                 g.get("topic", ""),
+                 " ".join(g.get("evaluation_criteria", []) or []),
+                 " ".join(g.get("expected_answers", []) or []),
+                 pfl_text,
+             ]
+         ).lower()
+         score = sum(token in blob for token in q.split())
+         if score > 0:
+             scored.append((score, g))
+
+     scored.sort(key=lambda x: x[0], reverse=True)
+     hits = []
+     for score, g in scored[:top_k]:
+         pfl = g.get("pass_fail_logic") or {}
+         pfl_text = " ".join(
+             f"{k}: {v}" for k, v in pfl.items()
+         )
+         text = (
+             " ".join(g.get("evaluation_criteria", []) or [])
+             or " ".join(g.get("expected_answers", []) or [])
+             or pfl_text
+         )
+         hits.append(
+             {
+                 "source": "guideline",
+                 "text": text,
+                 "meta": {
+                     "guideline_id": g["guideline_id"],
+                     "topic": g.get("topic", ""),
+                 },
+             }
+         )
+     return hits
+
+
+ def classify_mode(description: str, feature_summary: Dict[str, Any]) -> str:
+     desc_lower = (description or "").lower()
+     feats = feature_summary.get("features", {})
+
+     image_type = (feature_summary.get("image_type") or "").lower()
+     has_gdt_flag = bool(feature_summary.get("has_gdt"))
+     has_dims_flag = bool(feature_summary.get("has_dimensions"))
+
+     has_datum = bool(feats.get("datum_symbols"))
+     has_gdt_feat = feats.get("gdt_frames_present", False)
+
+     cad_like_words = ["cad", "model", "solid", "surface", "bottle", "housing", "rib"]
+     drawing_like_words = ["drawing", "dimension", "tolerance"]
+
+     has_cad_words = any(w in desc_lower for w in cad_like_words)
+     has_drawing_words = any(w in desc_lower for w in drawing_like_words)
+
+     gd_signals = any(
+         [
+             image_type == "dimensioned_drawing",
+             has_gdt_flag,
+             has_gdt_feat,
+             has_datum,
+             has_dims_flag,
+             has_drawing_words,
+         ]
+     )
+     cad_signals = any(
+         [
+             image_type == "cad_model",
+             has_cad_words,
+         ]
+     )
+
+     if gd_signals and cad_signals:
+         return "mixed"
+     if gd_signals:
+         return "gdt"
+     if cad_signals:
+         return "dfm"
+     return "dfm"
+
+
+ def select_applicable_guidelines(
+     feature_summary: Dict[str, Any],
+     description: str,
+     max_guidelines: int = 5,
+ ) -> List[Dict[str, Any]]:
+     """
+     Choose a subset of guidelines out of all 20, based on dfm/gdt mode.
+     Returns lightweight dicts (guideline_id + topic), but the evaluator
+     will later look up the full objects from GUIDELINE_BY_ID.
+     """
+     mode = classify_mode(description, feature_summary)
+     suggestions = feature_summary.get("suggested_guidelines") or []
+
+     def category_of(g: Dict[str, Any]) -> str:
+         cat = (g.get("category") or "").lower()
+         if cat in ("dfm", "gdt"):
+             return cat
+         gid = (g.get("guideline_id") or "").upper()
+         if gid.startswith("D"):
+             return "dfm"
+         if gid.startswith("G"):
+             return "gdt"
+         return ""
+
+     picked: List[Dict[str, Any]] = []
+     suggested_ids = set()
+
+     # 1) honour any suggested_guidelines (if they match the mode)
+     for s in suggestions:
+         gid = s.get("guideline_id")
+         if not gid:
+             continue
+         g = GUIDELINE_BY_ID.get(gid)
+         if not g:
+             continue
+         cat = category_of(g)
+         if mode == "gdt" and cat != "gdt":
+             continue
+         if mode == "dfm" and cat != "dfm":
+             continue
+         picked.append({"guideline_id": gid, "topic": g.get("topic", "")})
+         suggested_ids.add(gid)
+
+     # 2) fill in from ALL_GUIDELINES based on mode
+     for g in ALL_GUIDELINES:
+         gid = g["guideline_id"]
+         if gid in suggested_ids:
+             continue
+         cat = category_of(g)
+         if mode == "gdt" and cat == "gdt":
+             picked.append({"guideline_id": gid, "topic": g["topic"]})
+         elif mode == "dfm" and cat == "dfm":
+             picked.append({"guideline_id": gid, "topic": g["topic"]})
+         elif mode == "mixed" and cat in ("gdt", "dfm"):
+             picked.append({"guideline_id": gid, "topic": g["topic"]})
+
+     # 3) in mixed mode, bias GD&T first
+     if mode == "mixed":
+         def is_gdt(gid: str) -> bool:
+             g = GUIDELINE_BY_ID.get(gid, {})
+             return category_of(g) == "gdt"
+
+         picked.sort(key=lambda x: 0 if is_gdt(x["guideline_id"]) else 1)
+
+     return picked[:max_guidelines]
+
+
+ # ============================================================
+ # 3. Evaluation utilities
+ # ============================================================
+
+ def extract_json_from_text(text: str) -> Dict[str, Any]:
+     m = re.search(r"\{.*\}", text, re.DOTALL)
+     if not m:
+         return {"parse_error": True, "raw": text}
+     try:
+         return json.loads(m.group(0))
+     except Exception:
+         return {"parse_error": True, "raw": text}
+
+
+ def downgrade_if_no_measurements(
+     eval_json: Dict[str, Any],
+     qa_text: str,
+ ) -> Dict[str, Any]:
+     q_lower = (qa_text or "").lower()
+     no_data = any(
+         phrase in q_lower
+         for phrase in [
+             "no measurement data",
+             "no measured data",
+             "assume 0 mm",
+             "assume zero",
+             "no cmm data",
+         ]
+     )
+     if not no_data:
+         return eval_json
+
+     sensitive_topics = [
+         "True Position",
+         "Profile",
+         "Flatness",
+         "Concentricity",
+         "Runout",
+         "Cylindricity",
+         "Circularity",
+     ]
+
+     for g in eval_json.get("guidelines", []):
+         topic = g.get("topic", "")
+         if any(t in topic for t in sensitive_topics):
+             g["result"] = "NEEDS_INFO"
+             g["reason"] = (
+                 "This guideline depends on measurement data, and you mentioned that "
+                 "measurements are not available yet. That's completely fine at the "
+                 "design stage, so this is marked as NEEDS_INFO rather than PASS/FAIL."
+             )
+             g["recommendation"] = (
+                 "Once you have inspection or simulation data, you can re-run this check "
+                 "to confirm the tolerance is still realistic."
+             )
+
+     return eval_json
+
+
+ def calibrate_eval_scores(eval_json: Dict[str, Any]) -> Dict[str, Any]:
+     guidelines = eval_json.get("guidelines", [])
+     eval_json.setdefault("overall", {})
+
+     if not guidelines:
+         eval_json["overall"].update(
+             {
+                 "summary": "No guidelines were evaluated.",
+                 "verdict": "NEEDS_MORE_DATA",
+                 "manufacturability_score": 0.6,
+             }
+         )
+         return eval_json
+
+     weights = {"PASS": 1.0, "NEEDS_INFO": 0.7, "FAIL": 0.0}
+     results = [g.get("result", "NEEDS_INFO") for g in guidelines]
+
+     if all(r == "NEEDS_INFO" for r in results):
+         eval_json["overall"].update(
+             {
+                 "summary": (
+                     "All guidelines are marked as NEEDS_INFO for now because some data "
+                     "is missing. That's okay—this just means more information will make "
+                     "the review stronger later."
+                 ),
+                 "verdict": "NEEDS_MORE_DATA",
+                 "manufacturability_score": 0.65,
+             }
+         )
+         return eval_json
+
+     scores = [weights.get(r, 0.7) for r in results]
+     avg = sum(scores) / len(scores)
+
+     if avg > 0.9:
+         verdict = "GOOD"
+     elif avg > 0.75:
+         verdict = "ACCEPTABLE"
+     elif avg > 0.6:
+         verdict = "RISKY"
+     else:
+         verdict = "NEEDS_MORE_DATA"
+
+     eval_json["overall"].update(
+         {
+             "summary": (
+                 "Automatic manufacturability summary based on the "
+                 "reviewed guidelines."
+             ),
+             "verdict": verdict,
+             "manufacturability_score": round(float(avg), 2),
+         }
+     )
+     return eval_json
+
+
+ def sanitize_eval_language(
+     eval_json: Dict[str, Any],
+     description: str,
+     feature_summary: Dict[str, Any],
+ ) -> Dict[str, Any]:
+     desc_lower = (description or "").lower()
+     feats = feature_summary.get("features", {})
+
+     is_machined = any(
+         w in desc_lower for w in ["machined", "cnc", "turned", "lathe", "ground"]
+     )
+     is_molded_like = feats.get("possible_draft", False) or any(
+         w in desc_lower for w in ["mold", "mould", "injection", "cast", "die cast"]
+     )
+
+     guideline_explanations = {
+         "True Position Tolerance": (
+             "True position helps ensure that holes or pins line up correctly in "
+             "assembly, so parts fit together without binding or excessive play."
          ),
+         "Profile Tolerance": (
+             "Profile controls how closely a surface matches its ideal CAD shape. "
+             "This matters a lot for sealing, smooth airflow, and consistent contact."
+         ),
+         "Flatness": (
+             "Flatness makes sure a surface does not bow or warp, which is important "
+             "for good sealing and accurate mounting faces."
+         ),
+         "Concentricity": (
+             "Concentricity ensures that different cylindrical features share the same "
+             "axis. This is crucial for rotating parts, shafts, and precision fits."
+         ),
+     }
+
+     encouraging_phrases = {
+         "PASS": (
+             "Nice work—this guideline looks solid. If you want to go further, you "
+             "could explore tolerance stack-ups or measurement planning for production."
+         ),
+         "NEEDS_INFO": (
+             "This isn’t a failure—it just means more information (like measurements "
+             "or simulation results) would help finish the story."
+         ),
+         "FAIL": (
+             "This might cause manufacturability or inspection challenges, but it's a "
+             "great opportunity to iterate and improve the design early."
+         ),
+     }
+
+     for g in eval_json.get("guidelines", []):
+         topic = g.get("topic", "")
+         result = g.get("result", "NEEDS_INFO")
+
+         if topic in guideline_explanations:
+             g["why_it_matters"] = guideline_explanations[topic]
+
+         g.setdefault("recommendation", "")
+         g["recommendation"] = (g["recommendation"] or "").strip()
+         extra = encouraging_phrases.get(result)
+         if extra:
+             if g["recommendation"]:
+                 g["recommendation"] += " "
+             g["recommendation"] += extra
+
+         # clean out weird generic ranges / hole size hallucinations
+         for key in ["reason", "recommendation"]:
+             text = g.get(key, "")
+             if not isinstance(text, str):
+                 continue
+
+             sentences = re.split(r"(?<=[.!?])\s+", text)
+             cleaned_sents = []
+             for s in sentences:
+                 s_lower = s.lower()
+                 if (
+                     "typical range" in s_lower
+                     or "small holes" in s_lower
+                     or "< 5 mm" in s_lower
+                     or "less than 5 mm" in s_lower
+                 ):
+                     continue
+                 cleaned_sents.append(s)
+
+             new_text = " ".join(cleaned_sents).strip()
+
+             if is_machined and not is_molded_like:
+                 new_text = (
+                     new_text.replace(
+                         "molding process capabilities",
+                         "machining process capabilities",
+                     )
+                     .replace("molding process capability", "machining process capability")
+                     .replace("molding process", "machining process")
+                 )
+
+             g[key] = new_text
+
+     overall = eval_json.get("overall", {})
+     if overall.get("verdict") == "POOR":
+         overall["verdict"] = "NEEDS_MORE_DATA"
+         overall["summary"] = (
+             "Some guidelines look challenging with the current information, but that "
+             "just means there is room to refine the design and collect more data."
+         )
+     eval_json["overall"] = overall
+     return eval_json
+
+
+ def evaluation_agent_txt(
+     description: str,
+     guidelines: List[Dict[str, Any]],
+     qa_text: str,
+     feature_summary: Dict[str, Any],
+ ) -> Dict[str, Any]:
+     """
+     Core evaluator: this is where we now pass in:
+       - evaluation_criteria
+       - expected_answers
+       - pass_fail_logic
+     for EACH guideline, so the model can truly reason over your 20 rules.
+     """
+     # Enrich guideline objects from the global GUIDELINE_BY_ID
+     enriched_guidelines = []
+     for g in guidelines:
+         gid = g.get("guideline_id")
+         base = GUIDELINE_BY_ID.get(gid, {})
+         enriched_guidelines.append(
+             {
+                 "guideline_id": gid,
+                 "topic": base.get("topic", g.get("topic", "")),
+                 "category": base.get("category", ""),
+                 "evaluation_criteria": base.get("evaluation_criteria", []),
+                 "user_questions": base.get("user_questions", []),
+                 "expected_answers": base.get("expected_answers", []),
+                 "pass_fail_logic": base.get("pass_fail_logic", {}),
+             }
+         )
+
+     rag_query_text = " ".join(
+         [
+             description or "",
+             qa_text or "",
+             json.dumps(feature_summary.get("features", {})),
+         ]
+     )
+     rag_hits = rag_retrieve(rag_query_text, top_k=6)
+
+     rag_context_lines = []
+     for h in rag_hits:
+         meta = h.get("meta", {})
+         gid = meta.get("guideline_id", "UNKNOWN")
+         topic = meta.get("topic", "")
+         rag_context_lines.append(f"[GUIDELINE {gid} - {topic}]\n{h['text']}")
+     rag_context = (
+         "\n\n---\n\n".join(rag_context_lines)
+         if rag_context_lines
+         else "(no extra context)"
+     )
+
+     sys_prompt = (
+         "You are a senior manufacturing / GD&T engineer and a patient instructor.\n"
+         "You are given:\n"
+         "- An optional short description of the part/product\n"
+         "- A set of DFM/GD&T guidelines to apply (including evaluation_criteria,\n"
+         "  expected_answers, and pass_fail_logic for each guideline)\n"
+         "- A Q&A history where the student answered questions about each guideline\n"
+         "- A feature summary extracted from CAD/drawing images\n"
+         "- Additional reference passages from a guideline knowledge base (RAG)\n\n"
+         "Your goals:\n"
+         "1) For EACH guideline, use the student's numeric/text answers and the\n"
+         "   'pass_fail_logic' rules to decide whether the guideline is PASS, FAIL,\n"
+         "   or NEEDS_INFO.\n"
+         "   • PASS = clearly satisfies the numeric / logical rules.\n"
+         "   • FAIL = clearly violates at least one rule in pass_fail_logic.\n"
+         "   • NEEDS_INFO = only if you truly cannot tell from the Q&A + features.\n"
+         "2) Refer directly to the variables in pass_fail_logic (e.g., nominal_wall,\n"
+         "   variation, rib_or_boss_thickness) and the numbers in the Q&A when\n"
+         "   making decisions. Treat the rules as engineering check equations.\n"
+         "3) Explain briefly WHY in clear engineering language.\n"
+         "4) Offer encouraging, actionable recommendations—talk like a helpful TA.\n"
+         "5) Comment qualitatively on tolerance feasibility in the 'overall' block.\n\n"
+         "IMPORTANT:\n"
+         "- You MUST try to produce PASS or FAIL when the numeric conditions are\n"
+         "  clearly satisfied or violated. Do NOT default to NEEDS_INFO if the\n"
+         "  student already provided the key numbers.\n"
+         "- Only use NEEDS_INFO when the data is genuinely missing or ambiguous.\n\n"
+         "Respond ONLY as a single JSON object with this schema:\n"
+         "{\n"
+         '  "guidelines": [\n'
+         "    {\n"
+         '      "guideline_id": str,\n'
+         '      "topic": str,\n'
+         '      "result": "PASS" | "FAIL" | "NEEDS_INFO",\n'
+         '      "reason": str,\n'
+         '      "recommendation": str\n'
+         "    }\n"
+         "  ],\n"
+         '  "overall": {\n'
+         '    "summary": str,\n'
+         '    "verdict": "GOOD" | "ACCEPTABLE" | "RISKY" | "NEEDS_MORE_DATA",\n'
+         '    "manufacturability_score": float\n'
+         "  }\n"
+         "}\n"
+     )
+
+     user_parts = [
+         "DESCRIPTION:",
+         description or "(none provided)",
+         "\n\nGUIDELINES UNDER REVIEW (with criteria and logic):",
+         json.dumps(enriched_guidelines, indent=2),
+         "\n\nQ&A HISTORY (questions and answers as free text):",
+         qa_text or "(no questions asked yet)",
+         "\n\nFEATURE SUMMARY FROM IMAGE(S):",
+         json.dumps(feature_summary, indent=2),
+         "\n\nRETRIEVED REFERENCES (RAG):",
+         rag_context,
+         "\n\nProduce ONLY the JSON object.",
+     ]
+     user_prompt = "\n".join(user_parts)
+
+     raw = run_text_llm(sys_prompt, user_prompt, max_new_tokens=1024)
+     eval_json = extract_json_from_text(raw)
+
+     if not eval_json.get("parse_error"):
+         eval_json = downgrade_if_no_measurements(eval_json, qa_text)
+         eval_json = calibrate_eval_scores(eval_json)
+         eval_json = sanitize_eval_language(eval_json, description, feature_summary)
+     return eval_json
+
+
+ def summarize_eval_for_student(eval_json: Dict[str, Any]) -> str:
+     guidelines = eval_json.get("guidelines", [])
+     overall = eval_json.get("overall", {})
+
+     lines: List[str] = []
+     lines.append(
+         "Thanks, that’s all the questions I needed for now. "
+         "Here’s your manufacturability snapshot based on those answers:"
+     )
+     lines.append("")
+
+     score = overall.get("manufacturability_score")
+     verdict = overall.get("verdict")
+     summary = overall.get("summary", "")
+
+     if score is not None or verdict:
+         headline = "• Overall verdict: "
+         if verdict:
+             headline += str(verdict)
+         if score is not None:
+             headline += f" (score ≈ {score:.2f})"
+         lines.append(headline)
+
+     if summary:
+         lines.append(f"• Summary: {summary}")
+     lines.append("")
+
+     if guidelines:
+         lines.append("Guideline-by-guideline notes:")
+         for g in guidelines:
+             topic = g.get("topic", "Unnamed guideline")
+             result = g.get("result", "NEEDS_INFO")
+             reason = g.get("reason", "")
+             rec = g.get("recommendation", "")
+             lines.append(f"- {topic} → {result}")
+             if reason:
+                 lines.append(f"  • Why: {reason}")
+             if rec:
+                 lines.append(f"  • Suggestion: {rec}")
+     else:
+         lines.append(
+             "I wasn’t able to evaluate any specific guidelines, likely because "
+             "we didn’t get enough structured answers."
+         )
+
+     lines.append("")
+     lines.append(
+         "If you’d like to see the raw JSON data for debugging or research, "
+         "you can ask: “show me the JSON summary.”"
+     )
+     return "\n".join(lines)
+
+
+ # ============================================================
+ # 4. Conversation state & router
+ # ============================================================
+
+ @dataclass
+ class GuidelineConversationState:
+     selected_guidelines: List[Dict[str, Any]] = field(default_factory=list)
+     current_guideline_idx: int = 0
+     qa_log: List[Tuple[str, str]] = field(default_factory=list)
+     max_questions: int = 8
+     questions_asked: int = 0
+     feature_summary: Dict[str, Any] = field(default_factory=dict)
+     description: str = ""
+
+
+ def current_guideline(
+     state: GuidelineConversationState,
+ ) -> Optional[Dict[str, Any]]:
+     if 0 <= state.current_guideline_idx < len(state.selected_guidelines):
+         return state.selected_guidelines[state.current_guideline_idx]
+     return None
+
+
+ def build_intro_message(
+     description: str,
+     feature_summary: Dict[str, Any],
+     selected_guidelines: List[Dict[str, Any]],
+     max_questions: int,
+ ) -> str:
+     gen_desc = feature_summary.get("generated_description") or ""
+     raw_notes = feature_summary.get("raw_notes") or ""
+
+     desc_bits = []
+     if gen_desc:
+         desc_bits.append(gen_desc)
+     if description:
+         desc_bits.append(description)
+     if raw_notes:
+         desc_bits.append(raw_notes)
+
+     combined_desc = (
+         " ".join(desc_bits)
+         if desc_bits
+         else "I’ll infer as much as I can directly from your image."
+     )
+
+     guideline_topics = [g["topic"] for g in selected_guidelines]
+     guideline_list_str = (
+         ", ".join(guideline_topics)
+         if guideline_topics
+         else "a small set of relevant DFM/GD&T rules"
+     )
+
+     intro = (
+         f"{combined_desc}\n\n"
+         "Based on this, I’ll walk you through a short manufacturability review.\n"
+         f"We’ll look at these guidelines: {guideline_list_str}.\n"
+         "I’ll ask at most ~"
+         f"{max_questions} focused questions, and then summarize how "
+         "manufacturable this design looks and where you could improve it.\n\n"
+         "Let’s start with the first guideline."
+     )
+     return intro
+
+
+ def get_guideline_questions(gid: str) -> List[str]:
+     g = GUIDELINE_BY_ID.get(gid)
+     if not g:
+         return []
+     qs = g.get("user_questions") or g.get("questions") or []
+     out = []
+     for q in qs:
+         if isinstance(q, str):
+             out.append(q)
+         elif isinstance(q, dict) and "question" in q:
+             out.append(q["question"])
+     return out
+
+
+ def classify_user_turn(user_text: str, last_question: str) -> str:
+     """
+     Tiny router: is the user answering the guideline question,
+     or asking their own side question?
+     Returns "answer" or "student_question".
+     """
+     sys_prompt = (
+         "You are a routing model for a tutoring chat about DFM/GD&T.\n"
+         "Given the last question asked by the tutor and the student's reply,\n"
+         "decide if the student is primarily ANSWERING the question, or asking a new\n"
+         "QUESTION of their own (e.g., 'can I add a fillet here?').\n\n"
+         "Reply ONLY in JSON like {\"label\": \"answer\"} or "
+         "{\"label\": \"student_question\"}."
+     )
+     user_prompt = (
+         f"Tutor_question: {last_question}\n"
+         f"Student_message: {user_text}\n"
+         "Label:"
+     )
+     raw = run_text_llm(sys_prompt, user_prompt, max_new_tokens=64)
+     m = re.search(r"\{.*\}", raw, re.DOTALL)
+     if not m:
+         return "answer"
+     try:
+         obj = json.loads(m.group(0))
+         label = (obj.get("label") or "").lower()
+         if label in {"answer", "student_question"}:
+             return label
+     except Exception:
+         pass
+     return "answer"
+
+
+ def answer_student_question(
+     user_text: str,
+     state: GuidelineConversationState,
+     chat_history: List[Tuple[str, str]],
+ ) -> str:
+     """
+     Use the same model to answer a side-question in a friendly way.
+     This does NOT advance the guideline review.
+     """
+     last_q = chat_history[-1][0] if chat_history else ""
+     qa_snippets = []
+     for q, a in state.qa_log[-3:]:
+         qa_snippets.append(f"Q: {q}\nA: {a}")
+     qa_str = "\n---\n".join(qa_snippets) if qa_snippets else "(no prior Q&A)"
+
+     sys_prompt = (
+         "You are a friendly manufacturing / GD&T teaching assistant inside a small app.\n"
+         "The student may ask meta-questions like 'can I add a fillet here?', "
+         "'is this draft enough?', or 'what tolerance should I use?'.\n"
+         "Use the selected DFM/GD&T guidelines, the feature summary, and their answers\n"
+         "to give concrete, practical advice.\n\n"
+         "Prefer to reference guidelines by topic (e.g., Wall Thickness, Draft Angle).\n"
+         "Talk about trade-offs (manufacturability, cost, risk).\n"
+         "Keep answers short (2–6 sentences).\n"
+         "Do NOT output JSON; just respond as normal helpful text."
+     )
+     user_parts = [
+         "Part description:",
+         state.description or "(none)",
+         "\nFeature summary:",
+         json.dumps(state.feature_summary, indent=2),
+         "\nSelected guidelines:",
+         json.dumps(state.selected_guidelines, indent=2),
+         "\nRecent Q&A:",
+         qa_str,
+         "\nLast tutor question:",
+         last_q or "(none)",
+         "\nStudent question:",
+         user_text,
+     ]
+     user_prompt = "\n".join(user_parts)
+     reply = run_text_llm(sys_prompt, user_prompt, max_new_tokens=256)
+     return reply
+
+
+ def step_conversation(
+     chat_history: List[Tuple[str, str]],
+     user_message: str,
+     state: GuidelineConversationState,
+ ) -> Tuple[List[Tuple[str, str]], GuidelineConversationState]:
+     """
+     One conversation step for an ANSWER (router already decided).
+     """
+     # Log student's answer into QA log
+     if chat_history and user_message.strip():
+         last_assistant, _ = chat_history[-1]
+         state.qa_log.append((last_assistant, user_message))
+         state.questions_asked += 1
+
+     # Stopping condition
+     if state.questions_asked >= state.max_questions or not current_guideline(state):
+         qas_text = "\n".join([f"Q: {q}\nA: {a}" for q, a in state.qa_log])
+         eval_json = evaluation_agent_txt(
+             state.description,
+             state.selected_guidelines,
+             qas_text,
+             state.feature_summary,
+         )
+         friendly_summary = summarize_eval_for_student(eval_json)
+         chat_history.append((friendly_summary, ""))
+         return chat_history, state
+
+     # Otherwise, determine next question
+     current = current_guideline(state)
+     gid = current["guideline_id"]
+     topic = current["topic"]
+     questions = get_guideline_questions(gid)
+
+     asked_for_this_topic = [q for q, _ in state.qa_log if topic in q]
+     idx = len(asked_for_this_topic)
+
+     if idx >= len(questions):
+         # move to next guideline
+         state.current_guideline_idx += 1
+         if not current_guideline(state):
+             return step_conversation(chat_history, user_message, state)
+         current = current_guideline(state)
+         gid = current["guideline_id"]
+         topic = current["topic"]
+         questions = get_guideline_questions(gid)
+         idx = 0
+         if not questions:
+             return step_conversation(chat_history, user_message, state)
+
+     q_text = questions[idx]
+     header = (
+         f"Now let’s look at {topic}.\n\n"
+         "For this guideline, we’re checking a few key points from your DFM/GD&T rules. "
+         "I’ll ask a quick question to see whether your design satisfies it.\n\n"
+     )
+     full_q = header + q_text
+     chat_history.append((full_q, ""))
+     return chat_history, state
+
+
+ # --------- helper to convert internal tuples -> Chatbot messages ----------
+
+ def tuples_to_messages(history: List[Tuple[str, str]]) -> List[Dict[str, Any]]:
+     """
+     Convert [(assistant, user), ...] to Chatbot 'messages' format:
+     [{"role": "assistant", "content": "..."},
+      {"role": "user", "content": "..."}, ...]
+     """
+     messages: List[Dict[str, Any]] = []
+     for assistant_text, user_text in history:
+         if assistant_text:
+             messages.append({"role": "assistant", "content": assistant_text})
+         if user_text:
+             messages.append({"role": "user", "content": user_text})
+     return messages
+
+
+ # ============================================================
+ # 5. Gradio UI
+ # ============================================================
+
+ with gr.Blocks(title="DFM / GD&T Manufacturability Tutor") as demo:
+     gr.Markdown(
+         """
+         # 📐 DFM / GD&T Manufacturability Tutor
+         1. Upload **1–3 CAD screenshots or drawings**
+         2. *(Optional)* Add a short description of the part
+         3. Click **Start review**
+         4. Answer a few focused questions → get a guideline-by-guideline summary
+         This tool is meant to feel like a mini design review with a friendly TA.
+         """
+     )
+
+     state = gr.State(GuidelineConversationState())
+     chat_state = gr.State([])  # internal: list[Tuple[str, str]]
+
+     with gr.Row():
+         with gr.Column(scale=3):
+             chat = gr.Chatbot(
+                 label="Conversation",
+                 height=480,
+             )
+             user_box = gr.Textbox(
+                 label="Your answer or question",
+                 placeholder=(
+                     "Answer the current question, or ask something like "
+                     "'can I 3D print this?'"
+                 ),
+             )
+             start_btn = gr.Button("▶️ Start review (or restart)")
+         with gr.Column(scale=2):
+             image_input = gr.Image(
+                 type="numpy",
+                 label="Upload 1–3 CAD/drawing screenshots",
+             )
+             description_box = gr.Textbox(
+                 label="(Optional) Short description of the part",
+                 placeholder="e.g., 'Machined plunger for a relief valve with 60° cone'",
+             )
+             max_q_slider = gr.Slider(
+                 label="Max questions",
+                 minimum=3,
+                 maximum=12,
+                 value=8,
+                 step=1,
+             )
+             feature_debug = gr.JSON(
+                 label="Feature Summary (debug)",
+                 visible=False,
+             )
+             guideline_debug = gr.JSON(
+                 label="Selected Guidelines (debug)",
+                 visible=False,
+             )
+
+     # ---------- Event wiring ----------
+     def _start(images, desc, max_q):
+         """
+         Gradio callback for 'Start review (or restart)'.
+         Normalize images, run feature extractor, pick guidelines,
+         compose intro + first question.
+         """
+         if images is None:
+             image_list: List[np.ndarray] = []
+         elif isinstance(images, list):
+             image_list = images
+         else:
+             image_list = [images]
+
+         pil_images = [Image.fromarray(img) for img in image_list] if image_list else []
+         feature_summary = extract_visual_features(pil_images)
+         selected = select_applicable_guidelines(
+             feature_summary,
+             desc or "",
+             max_guidelines=5,
+         )
+
+         state_obj = GuidelineConversationState(
+             selected_guidelines=selected,
+             current_guideline_idx=0,
+             qa_log=[],
+             max_questions=int(max_q),
+             questions_asked=0,
+             feature_summary=feature_summary,
+             description=desc or "",
+         )
+
+         chat_tuples: List[Tuple[str, str]] = []
+         intro_msg = build_intro_message(
+             desc or "",
+             feature_summary,
+             selected,
+             int(max_q),
+         )
+         chat_tuples.append((intro_msg, ""))
+
+         # Ask first guideline question
+         chat_tuples, state_obj = step_conversation(chat_tuples, "", state_obj)
+
+         chat_messages = tuples_to_messages(chat_tuples)
+         return chat_messages, "", feature_summary, selected, state_obj, chat_tuples
+
+     def _answer(user_text, tuple_history, state_obj: GuidelineConversationState):
+         """
+         Gradio callback for the textbox submit.
+         - Route the user turn to 'answer' vs 'student_question'
+         - If answer → advance guideline flow
+         - If student_question → chatty side-answer, no state advancement
+         """
+         chat_history: List[Tuple[str, str]] = tuple_history or []
+         user_text = (user_text or "").strip()
+         if not user_text:
+             chat_messages = tuples_to_messages(chat_history)
+             return chat_messages, "", state_obj, chat_history
+
+         last_question = chat_history[-1][0] if chat_history else ""
+         label = classify_user_turn(user_text, last_question)
+
+         if label == "student_question":
+             reply = answer_student_question(user_text, state_obj, chat_history)
+             chat_history.append((reply, ""))
+             chat_messages = tuples_to_messages(chat_history)
+             return chat_messages, "", state_obj, chat_history
+
+         # label == "answer": attach answer to last question and advance
+         if chat_history:
+             last_q, _ = chat_history[-1]
+             chat_history[-1] = (last_q, user_text)
+
+         chat_history, new_state = step_conversation(
+             chat_history,
+             user_text,
+             state_obj,
+         )
+         chat_messages = tuples_to_messages(chat_history)
+         return chat_messages, "", new_state, chat_history
+
+     # Button → start/restart the review
+     start_btn.click(
+         _start,
+         inputs=[image_input, description_box, max_q_slider],
+         outputs=[chat, user_box, feature_debug, guideline_debug, state, chat_state],
+     )

+     # Textbox submit → route + respond
+     user_box.submit(
+         _answer,
+         inputs=[user_box, chat_state, state],
+         outputs=[chat, user_box, state, chat_state],
+     )


  if __name__ == "__main__":
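
A note for readers of the diff above: the new `load_guidelines` accepts a guidelines_final.json that is simply several JSON objects written back to back, peeling them off one at a time with `json.JSONDecoder.raw_decode`. A minimal standalone sketch of that parsing technique, separate from the commit itself (the two-object input string below is hypothetical, not the real guidelines file):

import json

# Hypothetical input: two JSON objects concatenated in one string/file.
raw = '{"guideline_id": "D1"} {"guideline_id": "G2"}'

decoder = json.JSONDecoder()
pos, objects = 0, []
while pos < len(raw):
    # Skip any whitespace separating the fragments.
    while pos < len(raw) and raw[pos].isspace():
        pos += 1
    if pos >= len(raw):
        break
    # raw_decode parses one JSON value starting at pos and returns
    # (value, index just past the value).
    obj, end = decoder.raw_decode(raw, pos)
    objects.append(obj)
    pos = end

print(objects)  # [{'guideline_id': 'D1'}, {'guideline_id': 'G2'}]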