hchevva commited on
Commit
42ffefc
·
verified ·
1 Parent(s): c73a29e

Update quread/llm_explain.py

Browse files
Files changed (1) hide show
  1. quread/llm_explain.py +22 -59
quread/llm_explain.py CHANGED
@@ -4,25 +4,19 @@ from __future__ import annotations
4
  from dataclasses import dataclass
5
  from typing import Any, Dict, List, Optional, Tuple
6
 
7
- import torch
8
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
9
 
10
 
11
  @dataclass
12
  class ExplainConfig:
13
- """
14
- Local (in-Space) explainer config.
15
- - Default model is small + reliable on CPU.
16
- - You can upgrade later (e.g., flan-t5-base/large) if performance allows.
17
- """
18
- model_id: str = "google/flan-t5-small"
19
  max_new_tokens: int = 220
20
- temperature: float = 0.2 # kept for future; seq2seq generate doesn't always use it
21
- device: str = "cpu" # Spaces free tier is CPU
22
 
23
 
24
- # --- simple in-memory cache so the model loads once per container ---
25
- _LOCAL_CACHE: Dict[str, Any] = {"model_id": None, "tokenizer": None, "model": None}
26
 
27
 
28
  def _build_grounded_prompt(
@@ -32,10 +26,6 @@ def _build_grounded_prompt(
32
  probs_top: List[Tuple[str, float]],
33
  shots: Optional[int] = None,
34
  ) -> str:
35
- """
36
- Prompt is explicitly grounded: it includes only circuit + computed outputs.
37
- The model is instructed not to invent values.
38
- """
39
  ops_lines = []
40
  for op in history:
41
  if op.get("type") == "single":
@@ -74,24 +64,16 @@ Return a concise explanation with bullet points and short paragraphs.
74
  """.strip()
75
 
76
 
77
- def _load_local_model(cfg: ExplainConfig):
78
- """
79
- Loads tokenizer+model once and caches them.
80
- Uses Seq2Seq model family (FLAN-T5) which is CPU-friendly.
81
- """
82
- if _LOCAL_CACHE["model"] is not None and _LOCAL_CACHE["model_id"] == cfg.model_id:
83
- return _LOCAL_CACHE["tokenizer"], _LOCAL_CACHE["model"]
84
 
85
  tok = AutoTokenizer.from_pretrained(cfg.model_id)
86
- model = AutoModelForSeq2SeqLM.from_pretrained(cfg.model_id)
87
 
88
- # Force CPU unless you later add GPU space
89
- model.to(cfg.device)
90
- model.eval()
91
-
92
- _LOCAL_CACHE["model_id"] = cfg.model_id
93
- _LOCAL_CACHE["tokenizer"] = tok
94
- _LOCAL_CACHE["model"] = model
95
  return tok, model
96
 
97
 
@@ -104,10 +86,6 @@ def explain_circuit_with_hf(
104
  shots: Optional[int] = None,
105
  cfg: Optional[ExplainConfig] = None,
106
  ) -> str:
107
- """
108
- Local explainer (runs inside the HF Space).
109
- Kept function name for compatibility with your app.py imports.
110
- """
111
  cfg = cfg or ExplainConfig()
112
 
113
  prompt = _build_grounded_prompt(
@@ -119,9 +97,8 @@ def explain_circuit_with_hf(
119
  )
120
 
121
  try:
122
- tok, model = _load_local_model(cfg)
123
 
124
- # Tokenize
125
  inputs = tok(
126
  prompt,
127
  return_tensors="pt",
@@ -129,31 +106,17 @@ def explain_circuit_with_hf(
129
  max_length=1024,
130
  )
131
 
132
- # Move tensors to device (CPU)
133
- for k in inputs:
134
- inputs[k] = inputs[k].to(cfg.device)
135
-
136
- # Generate
137
- with torch.no_grad():
138
- out_ids = model.generate(
139
- **inputs,
140
- max_new_tokens=int(cfg.max_new_tokens),
141
- )
142
 
143
  text = tok.decode(out_ids[0], skip_special_tokens=True).strip()
144
- if not text:
145
- return (
146
- "LLM call failed (local model returned empty output).\n\n"
147
- f"Local model: {cfg.model_id}\n"
148
- "Try increasing max_new_tokens or using flan-t5-base."
149
- )
150
- return text
151
 
152
  except Exception as e:
153
  return (
154
- "LLM call failed (local inference).\n\n"
155
- f"Local model: {cfg.model_id}\n"
156
- f"Error: {repr(e)}\n\n"
157
- "If this is an out-of-memory error, use google/flan-t5-small.\n"
158
- "If it is a missing dependency, confirm transformers + torch are in requirements.txt."
159
  )
 
4
  from dataclasses import dataclass
5
  from typing import Any, Dict, List, Optional, Tuple
6
 
7
+ from transformers import AutoTokenizer
8
+ from optimum.onnxruntime import ORTModelForSeq2SeqLM
9
 
10
 
11
@dataclass
class ExplainConfig:
    """Configuration for the local ONNX explainer.

    Attributes are plain dataclass fields; all have CPU-friendly defaults.
    """
    # ONNX export of FLAN-T5 small — runs via onnxruntime, no torch needed.
    model_id: str = "onnx-community/flan-t5-small-ONNX"
    # Upper bound on generated tokens per explanation.
    max_new_tokens: int = 220
    # Not consumed by ORT generate(); retained so callers passing it keep working.
    temperature: float = 0.2
 
17
 
18
 
19
+ _CACHE: Dict[str, Any] = {"model_id": None, "tok": None, "model": None}
 
20
 
21
 
22
  def _build_grounded_prompt(
 
26
  probs_top: List[Tuple[str, float]],
27
  shots: Optional[int] = None,
28
  ) -> str:
 
 
 
 
29
  ops_lines = []
30
  for op in history:
31
  if op.get("type") == "single":
 
64
  """.strip()
65
 
66
 
67
def _load_onnx(cfg: ExplainConfig):
    """Return a (tokenizer, model) pair for ``cfg.model_id``, loading at most once.

    Results are memoized in the module-level ``_CACHE`` so repeated calls
    (e.g. per request in a Space) reuse the already-loaded ONNX model.
    """
    # Cache hit: same model id and a model object is present.
    if _CACHE["model_id"] == cfg.model_id and _CACHE["model"] is not None:
        return _CACHE["tok"], _CACHE["model"]

    tok = AutoTokenizer.from_pretrained(cfg.model_id)
    model = ORTModelForSeq2SeqLM.from_pretrained(cfg.model_id)

    # Populate the cache in one shot before handing the pair back.
    _CACHE.update(model_id=cfg.model_id, tok=tok, model=model)
    return tok, model
78
 
79
 
 
86
  shots: Optional[int] = None,
87
  cfg: Optional[ExplainConfig] = None,
88
  ) -> str:
 
 
 
 
89
  cfg = cfg or ExplainConfig()
90
 
91
  prompt = _build_grounded_prompt(
 
97
  )
98
 
99
  try:
100
+ tok, model = _load_onnx(cfg)
101
 
 
102
  inputs = tok(
103
  prompt,
104
  return_tensors="pt",
 
106
  max_length=1024,
107
  )
108
 
109
+ out_ids = model.generate(
110
+ **inputs,
111
+ max_new_tokens=int(cfg.max_new_tokens),
112
+ )
 
 
 
 
 
 
113
 
114
  text = tok.decode(out_ids[0], skip_special_tokens=True).strip()
115
+ return text if text else "LLM returned empty response (ONNX). Try increasing max_new_tokens."
 
 
 
 
 
 
116
 
117
  except Exception as e:
118
  return (
119
+ "LLM call failed (ONNX local inference).\n\n"
120
+ f"Model: {cfg.model_id}\n"
121
+ f"Error: {repr(e)}"
 
 
122
  )