hchevva committed on
Commit
8a60fe7
·
verified ·
1 Parent(s): 6b69023

Update quread/llm_explain.py

Browse files
Files changed (1) hide show
  1. quread/llm_explain.py +70 -65
quread/llm_explain.py CHANGED
@@ -1,19 +1,28 @@
1
  # quread/llm_explain.py
2
  from __future__ import annotations
3
 
4
- import os
5
  from dataclasses import dataclass
6
  from typing import Any, Dict, List, Optional, Tuple
7
 
8
- from huggingface_hub import InferenceClient
 
9
 
10
 
11
  @dataclass
12
  class ExplainConfig:
13
- model_id: str = "HuggingFaceH4/zephyr-7b-beta" # you can change later
14
- provider: str = "hf-inference"
15
- max_new_tokens: int = 280
16
- temperature: float = 0.2
 
 
 
 
 
 
 
 
 
17
 
18
 
19
  def _build_grounded_prompt(
@@ -37,7 +46,6 @@ def _build_grounded_prompt(
37
  ops_lines.append(f"- {op}")
38
 
39
  top_lines = [f"- {b}: {p:.4f}" for b, p in probs_top]
40
-
41
  shots_line = f"Shots: {shots}\n" if shots is not None else ""
42
 
43
  return f"""
@@ -66,8 +74,25 @@ Return a concise explanation with bullet points and short paragraphs.
66
  """.strip()
67
 
68
 
69
- def _get_token() -> Optional[str]:
70
- return os.getenv("HF_TOKEN")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
 
73
  def explain_circuit_with_hf(
@@ -79,12 +104,11 @@ def explain_circuit_with_hf(
79
  shots: Optional[int] = None,
80
  cfg: Optional[ExplainConfig] = None,
81
  ) -> str:
 
 
 
 
82
  cfg = cfg or ExplainConfig()
83
- token = _get_token()
84
- if not token:
85
- return "HF_TOKEN is not set (Space Settings → Secrets → HF_TOKEN)."
86
-
87
- client = InferenceClient(provider=cfg.provider, token=token)
88
 
89
  prompt = _build_grounded_prompt(
90
  n_qubits=n_qubits,
@@ -94,61 +118,42 @@ def explain_circuit_with_hf(
94
  shots=shots,
95
  )
96
 
97
- last_error = None
98
-
99
- # 1) Chat completion
100
  try:
101
- resp = client.chat_completion(
102
- model=cfg.model_id,
103
- messages=[
104
- {"role": "system", "content": "You are a helpful quantum tutor."},
105
- {"role": "user", "content": prompt},
106
- ],
107
- max_tokens=cfg.max_new_tokens,
108
- temperature=cfg.temperature,
109
  )
110
- text = resp.choices[0].message.content if resp and resp.choices else ""
111
- if text and text.strip():
112
- return text.strip()
113
- last_error = ValueError("chat_completion returned empty text")
114
- except Exception as e:
115
- last_error = e
116
 
117
- # 2) Text generation
118
- try:
119
- out = client.text_generation(
120
- model=cfg.model_id,
121
- prompt=prompt,
122
- max_new_tokens=cfg.max_new_tokens,
123
- temperature=cfg.temperature,
124
- )
125
- if out and str(out).strip():
126
- return str(out).strip()
127
- last_error = ValueError("text_generation returned empty text")
128
- except Exception as e:
129
- last_error = e
130
 
131
- # 3) Text2Text only if the method exists AND the model looks like T5/FLAN
132
- try:
133
- is_t5_family = any(x in cfg.model_id.lower() for x in ["t5", "flan"])
134
- fn = getattr(client, "text2text_generation", None)
135
-
136
- if is_t5_family and fn is not None:
137
- out = fn(
138
- model=cfg.model_id,
139
- prompt=prompt,
140
- max_new_tokens=cfg.max_new_tokens,
141
  )
142
- if out and str(out).strip():
143
- return str(out).strip()
144
- last_error = ValueError("text2text_generation returned empty text")
 
 
 
 
 
 
145
 
146
  except Exception as e:
147
- last_error = e
148
-
149
- return (
150
- "LLM call failed.\n\n"
151
- f"Model: {cfg.model_id}\n"
152
- f"Provider: {cfg.provider}\n"
153
- f"Error: {repr(last_error)}"
154
- )
 
1
  # quread/llm_explain.py
2
  from __future__ import annotations
3
 
 
4
  from dataclasses import dataclass
5
  from typing import Any, Dict, List, Optional, Tuple
6
 
7
+ import torch
8
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
9
 
10
 
11
@dataclass
class ExplainConfig:
    """
    Configuration for the local (in-Space) explainer.

    The default model is small enough to run reliably on CPU; swap in a
    larger checkpoint (e.g. flan-t5-base/large) later if the hardware allows.
    """
    model_id: str = "google/flan-t5-small"
    max_new_tokens: int = 220
    temperature: float = 0.2  # currently unused; seq2seq generate() may ignore it
    device: str = "cpu"       # free-tier Spaces only provide CPU
22
+
23
+
24
+ # --- simple in-memory cache so the model loads once per container ---
25
+ _LOCAL_CACHE: Dict[str, Any] = {"model_id": None, "tokenizer": None, "model": None}
26
 
27
 
28
  def _build_grounded_prompt(
 
46
  ops_lines.append(f"- {op}")
47
 
48
  top_lines = [f"- {b}: {p:.4f}" for b, p in probs_top]
 
49
  shots_line = f"Shots: {shots}\n" if shots is not None else ""
50
 
51
  return f"""
 
74
  """.strip()
75
 
76
 
77
def _load_local_model(cfg: ExplainConfig):
    """
    Load and cache the tokenizer + model pair for ``cfg.model_id``.

    The pair is loaded at most once per process per model id and reused on
    subsequent calls.  Uses the Seq2Seq (FLAN-T5 style) model family, which
    is CPU-friendly.

    Args:
        cfg: Explainer configuration; ``model_id`` selects the checkpoint and
            ``device`` selects where the model lives ("cpu" on free Spaces).

    Returns:
        ``(tokenizer, model)`` with the model on ``cfg.device`` in eval mode.
    """
    if _LOCAL_CACHE["model"] is not None and _LOCAL_CACHE["model_id"] == cfg.model_id:
        # Cache hit. The cache key is the model id only, so the requested
        # device may have changed since the model was first loaded — move it
        # (a no-op when it is already there) before handing it back.
        _LOCAL_CACHE["model"].to(cfg.device)
        return _LOCAL_CACHE["tokenizer"], _LOCAL_CACHE["model"]

    # Switching models: drop any previously cached weights first so their
    # memory can be reclaimed before the new checkpoint is loaded (this
    # matters on small CPU Spaces).
    _LOCAL_CACHE["model_id"] = None
    _LOCAL_CACHE["tokenizer"] = None
    _LOCAL_CACHE["model"] = None

    tok = AutoTokenizer.from_pretrained(cfg.model_id)
    model = AutoModelForSeq2SeqLM.from_pretrained(cfg.model_id)

    # Force the configured device (CPU unless a GPU Space is added later).
    model.to(cfg.device)
    model.eval()  # inference only — disables dropout etc.

    _LOCAL_CACHE["model_id"] = cfg.model_id
    _LOCAL_CACHE["tokenizer"] = tok
    _LOCAL_CACHE["model"] = model
    return tok, model
96
 
97
 
98
  def explain_circuit_with_hf(
 
104
  shots: Optional[int] = None,
105
  cfg: Optional[ExplainConfig] = None,
106
  ) -> str:
107
+ """
108
+ Local explainer (runs inside the HF Space).
109
+ Kept function name for compatibility with your app.py imports.
110
+ """
111
  cfg = cfg or ExplainConfig()
 
 
 
 
 
112
 
113
  prompt = _build_grounded_prompt(
114
  n_qubits=n_qubits,
 
118
  shots=shots,
119
  )
120
 
 
 
 
121
  try:
122
+ tok, model = _load_local_model(cfg)
123
+
124
+ # Tokenize
125
+ inputs = tok(
126
+ prompt,
127
+ return_tensors="pt",
128
+ truncation=True,
129
+ max_length=1024,
130
  )
 
 
 
 
 
 
131
 
132
+ # Move tensors to device (CPU)
133
+ for k in inputs:
134
+ inputs[k] = inputs[k].to(cfg.device)
 
 
 
 
 
 
 
 
 
 
135
 
136
+ # Generate
137
+ with torch.no_grad():
138
+ out_ids = model.generate(
139
+ **inputs,
140
+ max_new_tokens=int(cfg.max_new_tokens),
 
 
 
 
 
141
  )
142
+
143
+ text = tok.decode(out_ids[0], skip_special_tokens=True).strip()
144
+ if not text:
145
+ return (
146
+ "LLM call failed (local model returned empty output).\n\n"
147
+ f"Local model: {cfg.model_id}\n"
148
+ "Try increasing max_new_tokens or using flan-t5-base."
149
+ )
150
+ return text
151
 
152
  except Exception as e:
153
+ return (
154
+ "LLM call failed (local inference).\n\n"
155
+ f"Local model: {cfg.model_id}\n"
156
+ f"Error: {repr(e)}\n\n"
157
+ "If this is an out-of-memory error, use google/flan-t5-small.\n"
158
+ "If it is a missing dependency, confirm transformers + torch are in requirements.txt."
159
+ )