aadisawant2912 commited on
Commit
ce6755c
Β·
verified Β·
1 Parent(s): f083a2e

Update tools_v2.py

Browse files
Files changed (1) hide show
  1. tools_v2.py +204 -114
tools_v2.py CHANGED
@@ -1,67 +1,89 @@
1
  # =============================================================================
2
- # V2 TOOL 3 β€” label_clusters_council_of_3 (TRUE multi-LLM ensemble)
3
  # =============================================================================
4
  @tool
5
  def label_clusters_council_of_3(batch_size: int = 5) -> str:
6
- """Label each cluster using a TRUE council of 3 DIFFERENT LLMs:
7
- 1. Mistral (mistral-small-latest)
8
- 2. OpenAI (gpt-4o-mini)
9
- 3. Groq (llama3-70b-8192)
10
- Each model receives the SAME prompt independently.
11
- Final label = mode (most common) of the 3 responses.
12
- Vote agreement = unanimous / majority / split.
13
- Saves enriched summaries + full audit CSV (one row per paper) to data/v2/.
14
-
15
- API keys are read automatically from environment variables:
 
 
16
  MISTRAL_API_KEY, OPENAI_API_KEY, GROQ_API_KEY
17
  Set these in HuggingFace Space β†’ Settings β†’ Variables and Secrets.
18
 
 
 
 
19
  Args:
20
  batch_size: Clusters per LLM call (default 5).
21
  """
22
  import time
23
- import os
24
-
25
- # ── NEW: import all 3 LangChain integrations ──────────────────────────────
26
  from langchain_mistralai import ChatMistralAI
27
  from langchain_openai import ChatOpenAI
28
  from langchain_groq import ChatGroq
29
- # ─────────────────────────────────────────────────────────────────────────
30
 
31
  p = _p2()
32
  clusters = json.loads(p["clusters"].read_text())
33
 
34
- # ── NEW: define 3 real LLMs (keys picked up from env automatically) ───────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  COUNCIL = [
36
  {
37
- "name": "MISTRAL",
38
- "model": ChatMistralAI(
39
- model="mistral-small-latest",
40
- temperature=0.2,
41
- # api_key read from MISTRAL_API_KEY env var automatically
42
- ),
43
  },
44
  {
45
- "name": "OPENAI",
46
- "model": ChatOpenAI(
47
- model="gpt-4o-mini",
48
- temperature=0.2,
49
- # api_key read from OPENAI_API_KEY env var automatically
50
- ),
51
  },
52
  {
53
- "name": "GROQ",
54
- "model": ChatGroq(
55
- model="llama3-70b-8192",
56
- temperature=0.2,
57
- # api_key read from GROQ_API_KEY env var automatically
58
- ),
59
  },
60
  ]
61
  # ─────────────────────────────────────────────────────────────────────────
62
 
63
- # ── UNCHANGED: single shared prompt builder (same prompt for all 3 LLMs) ──
64
- def make_prompt(batch):
65
  mini = [
66
  {
67
  "cluster_id": c["cluster_id"],
@@ -83,61 +105,116 @@ def label_clusters_council_of_3(batch_size: int = 5) -> str:
83
  )
84
  # ─────────────────────────────────────────────────────────────────────────
85
 
86
- # ── NEW: run each LLM independently across all batches ───────────────────
87
- # persona_results[i] = { cluster_id: {label, confidence, reasoning} }
88
- # shape is identical to before so all downstream code is UNCHANGED
89
- persona_results = [{}, {}, {}]
90
- batch_starts = list(range(0, len(clusters), batch_size))
 
 
 
 
91
 
92
- for pi, member in enumerate(COUNCIL):
93
- llm = member["model"]
94
- llm_name = member["name"]
95
- all_labels = []
96
 
97
- print(f"Council member {pi+1}/3 ({llm_name}) labeling {len(clusters)} clusters...")
98
 
99
  for bi, start in enumerate(batch_starts):
100
  batch = clusters[start: start + batch_size]
101
- prompt = make_prompt(batch) # same prompt for every LLM
102
 
103
- # ── NEW: per-model error handling so one failure doesn't kill all ─
104
- try:
105
- result = _call_llm_json(llm, prompt)
106
- all_labels.extend(result)
107
- except Exception as e:
108
- print(f" WARNING: {llm_name} batch {bi} failed: {e}. Using fallback labels.")
109
- for c in batch:
110
- all_labels.append({
111
- "cluster_id": c["cluster_id"],
112
- "label": f"Cluster {c['cluster_id']} ({llm_name} error)",
113
- "confidence": "Low",
114
- "reasoning": f"Fallback β€” {llm_name} error: {str(e)[:80]}",
115
- })
116
- # ─────────────────────────────────────────────────────────────────
117
-
118
- # small delay between batches to respect rate limits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  if bi < len(batch_starts) - 1:
120
- time.sleep(8)
 
 
 
 
 
 
 
 
 
 
 
121
 
122
- for item in all_labels:
123
- cid = int(item.get("cluster_id", 0))
124
- persona_results[pi][cid] = item
125
 
126
- # delay between council members (Groq is fast, Mistral/OpenAI need breathing room)
127
- if pi < len(COUNCIL) - 1:
128
- time.sleep(10)
 
 
 
 
 
 
 
 
129
  # ─────────────────────────────────────────────────────────────────────────
130
 
131
- # ── UNCHANGED from here down: voting + enrichment + CSV export ───────────
 
 
132
  def enrich(cluster):
133
- cid = cluster["cluster_id"]
134
  raw_votes = [
135
- str(persona_results[pi].get(cid, {}).get("label", "")).strip()
136
- for pi in range(3)
137
  ]
138
  votes = [
139
  v if v and v.lower() not in ("", "none", "null")
140
- else "Cluster {}".format(cid)
141
  for v in raw_votes
142
  ]
143
  final = _mode_label(votes)
@@ -149,62 +226,75 @@ def label_clusters_council_of_3(batch_size: int = 5) -> str:
149
  return {
150
  **cluster,
151
  "label": final,
152
- "llm_vote_1_MISTRAL": votes[0], # key renamed to match real model
153
- "llm_vote_2_OPENAI": votes[1], # key renamed to match real model
154
- "llm_vote_3_GROQ": votes[2], # key renamed to match real model
155
- "confidence_1": persona_results[0].get(cid, {}).get("confidence", ""),
156
- "confidence_2": persona_results[1].get(cid, {}).get("confidence", ""),
157
- "confidence_3": persona_results[2].get(cid, {}).get("confidence", ""),
158
- "reasoning_1": persona_results[0].get(cid, {}).get("reasoning", ""),
159
- "reasoning_2": persona_results[1].get(cid, {}).get("reasoning", ""),
160
- "reasoning_3": persona_results[2].get(cid, {}).get("reasoning", ""),
161
  "vote_agreement": agreement,
162
  }
163
 
164
  enriched = list(map(enrich, clusters))
165
  p["summaries"].write_text(json.dumps(enriched, indent=2, ensure_ascii=False))
 
166
 
167
- # Audit CSV β€” one row per paper in cluster
168
  rows = []
169
  for c in enriched:
170
  cid = c["cluster_id"]
171
  for li, paper in enumerate(c["papers"]):
172
  rows.append({
173
- "cluster_id": cid,
174
- "final_label": c["label"],
175
- "vote_agreement": c["vote_agreement"],
176
- "llm1_MISTRAL_label": c["llm_vote_1_MISTRAL"], # renamed
177
- "llm2_OPENAI_label": c["llm_vote_2_OPENAI"], # renamed
178
- "llm3_GROQ_label": c["llm_vote_3_GROQ"], # renamed
179
- "llm1_confidence": c["confidence_1"],
180
- "llm2_confidence": c["confidence_2"],
181
- "llm3_confidence": c["confidence_3"],
182
- "llm1_reasoning": c["reasoning_1"],
183
- "llm2_reasoning": c["reasoning_2"],
184
- "llm3_reasoning": c["reasoning_3"],
185
- "paper_doi": paper.get("doi", ""),
186
- "paper_title": paper.get("title", ""),
187
- "paper_year": paper.get("year", ""),
188
- "paper_journal": paper.get("journal", ""),
189
- "abstract_preview": paper.get("abstract", "")[:300],
190
- "combined_preview": paper.get("combined", "")[:200],
191
- "centroid_cosine_sim": round(float(
192
  c["centroid_sims"][li] if li < len(c["centroid_sims"]) else 0.0), 4),
193
- "hdbscan_probability": round(float(
194
  c["hdbscan_probs"][li] if li < len(c["hdbscan_probs"]) else 0.0), 4),
195
- "is_top3_centroid": "YES" if li in c["top3_paper_idx"] else "no",
196
  })
197
 
198
  pd.DataFrame(rows).to_csv(p["audit_csv"], index=False, encoding="utf-8-sig")
199
 
200
  unanimous = sum(1 for c in enriched if c["vote_agreement"] == "unanimous")
201
  majority = sum(1 for c in enriched if c["vote_agreement"] == "majority")
 
 
 
 
 
202
  return json.dumps({
203
- "clusters_labeled": len(enriched),
204
- "unanimous": unanimous,
205
- "majority": majority,
206
- "split": len(enriched) - unanimous - majority,
207
- "audit_csv_rows": len(rows),
208
- "council_members": [m["name"] for m in COUNCIL], # NEW: visible in output
209
- "note": "True 3-LLM ensemble (Mistral+OpenAI+Groq). Audit CSV ready ({} rows).".format(len(rows)),
210
- })
 
 
 
 
 
 
 
 
1
  # =============================================================================
2
+ # V2 TOOL 3 β€” label_clusters_council_of_3 (parallel + cached multi-LLM)
3
  # =============================================================================
4
  @tool
5
  def label_clusters_council_of_3(batch_size: int = 5) -> str:
6
+ """Label clusters using a TRUE council of 3 LLMs running IN PARALLEL:
7
+ 1. Mistral (mistral-small-latest)
8
+ 2. OpenAI (gpt-4o-mini)
9
+ 3. Groq (llama3-70b-8192)
10
+
11
+ SPEED: All 3 LLMs run concurrently via ThreadPoolExecutor β†’ ~3x faster.
12
+ COST: SHA-256 disk cache β€” identical prompts are NEVER sent twice.
13
+ Re-runs, retries, and reruns after crashes cost $0 for cached batches.
14
+ LIMITS: Per-model retry with exponential backoff. Groq gets a small stagger
15
+ delay so all 3 don't burst simultaneously on the first call.
16
+
17
+ API keys auto-read from env:
18
  MISTRAL_API_KEY, OPENAI_API_KEY, GROQ_API_KEY
19
  Set these in HuggingFace Space β†’ Settings β†’ Variables and Secrets.
20
 
21
+ Cache lives at: data/v2/llm_cache/
22
+ Clear the cache: delete that folder to force fresh API calls.
23
+
24
  Args:
25
  batch_size: Clusters per LLM call (default 5).
26
  """
27
  import time
28
+ import hashlib
29
+ import threading
30
+ from concurrent.futures import ThreadPoolExecutor, as_completed
31
  from langchain_mistralai import ChatMistralAI
32
  from langchain_openai import ChatOpenAI
33
  from langchain_groq import ChatGroq
 
34
 
35
  p = _p2()
36
  clusters = json.loads(p["clusters"].read_text())
37
 
38
+ # ── 1. DISK CACHE SETUP ──────────────────────────────────────────────────
39
+ # Each unique (model_name + prompt) gets its own JSON file.
40
+ # Hit β†’ free, instant, no API call.
41
+ # Miss β†’ call API, save result, never pay again for that prompt.
42
+ CACHE_DIR = p["dir"] / "llm_cache"
43
+ CACHE_DIR.mkdir(parents=True, exist_ok=True)
44
+ cache_lock = threading.Lock() # safe for concurrent reads/writes
45
+
46
+ def _cache_key(model_name: str, prompt: str) -> str:
47
+ digest = hashlib.sha256(f"{model_name}::{prompt}".encode()).hexdigest()
48
+ return digest
49
+
50
+ def _cache_get(model_name: str, prompt: str):
51
+ key = _cache_key(model_name, prompt)
52
+ path = CACHE_DIR / f"{key}.json"
53
+ with cache_lock:
54
+ if path.exists():
55
+ return json.loads(path.read_text(encoding="utf-8"))
56
+ return None # cache miss
57
+
58
+ def _cache_set(model_name: str, prompt: str, result):
59
+ key = _cache_key(model_name, prompt)
60
+ path = CACHE_DIR / f"{key}.json"
61
+ with cache_lock:
62
+ path.write_text(json.dumps(result, ensure_ascii=False), encoding="utf-8")
63
+ # ─────────────────────────────────────────────────────────────────────────
64
+
65
+ # ── 2. COUNCIL DEFINITION ────────────────────────────────────────────────
66
  COUNCIL = [
67
  {
68
+ "name": "MISTRAL",
69
+ "model": ChatMistralAI(model="mistral-small-latest", temperature=0.2),
70
+ "stagger": 0, # seconds to wait before first call
 
 
 
71
  },
72
  {
73
+ "name": "OPENAI",
74
+ "model": ChatOpenAI(model="gpt-4o-mini", temperature=0.2),
75
+ "stagger": 1, # slight stagger so 3 don't burst at t=0
 
 
 
76
  },
77
  {
78
+ "name": "GROQ",
79
+ "model": ChatGroq(model="llama3-70b-8192", temperature=0.2),
80
+ "stagger": 2,
 
 
 
81
  },
82
  ]
83
  # ─────────────────────────────────────────────────────────────────────────
84
 
85
+ # ── 3. SHARED PROMPT BUILDER (same for all 3 LLMs) ───────────────────────
86
+ def make_prompt(batch: list) -> str:
87
  mini = [
88
  {
89
  "cluster_id": c["cluster_id"],
 
105
  )
106
  # ─────────────────────────────────────────────────────────────────────────
107
 
108
+ # ── 4. SINGLE-MEMBER WORKER (runs in its own thread) ─────────────────────
109
+ # Handles: cache check β†’ stagger β†’ retry-with-backoff β†’ cache save
110
+ # Returns: { cluster_id: {label, confidence, reasoning} }
111
+ def run_one_member(member: dict) -> tuple[str, dict]:
112
+ """Returns (member_name, {cid: result_dict})"""
113
+ name = member["name"]
114
+ llm = member["model"]
115
+ stagger = member["stagger"]
116
+ results = {}
117
 
118
+ # small stagger so 3 threads don't all burst at the exact same millisecond
119
+ if stagger:
120
+ time.sleep(stagger)
 
121
 
122
+ batch_starts = list(range(0, len(clusters), batch_size))
123
 
124
  for bi, start in enumerate(batch_starts):
125
  batch = clusters[start: start + batch_size]
126
+ prompt = make_prompt(batch)
127
 
128
+ # ── cache check (free) ──────────────────────────────────────────
129
+ cached = _cache_get(name, prompt)
130
+ if cached is not None:
131
+ print(f" [{name}] batch {bi+1}/{len(batch_starts)} β†’ CACHE HIT (free)")
132
+ for item in cached:
133
+ results[int(item.get("cluster_id", 0))] = item
134
+ continue # skip API call entirely
135
+ # ───────────────────────────────────────────────────────────────
136
+
137
+ # ── API call with exponential backoff ───────────────────────────
138
+ MAX_RETRIES = 4
139
+ for attempt in range(MAX_RETRIES):
140
+ try:
141
+ print(f" [{name}] batch {bi+1}/{len(batch_starts)} attempt {attempt+1}")
142
+ batch_result = _call_llm_json(llm, prompt)
143
+
144
+ # save to cache immediately on success
145
+ _cache_set(name, prompt, batch_result)
146
+
147
+ for item in batch_result:
148
+ results[int(item.get("cluster_id", 0))] = item
149
+ break # success β†’ exit retry loop
150
+
151
+ except Exception as e:
152
+ wait = (2 ** attempt) * 15 # 15s, 30s, 60s, 120s
153
+ print(f" [{name}] batch {bi+1} attempt {attempt+1} FAILED: {e}")
154
+
155
+ if attempt < MAX_RETRIES - 1:
156
+ print(f" [{name}] retrying in {wait}s...")
157
+ time.sleep(wait)
158
+ else:
159
+ # all retries exhausted β†’ use fallback, do NOT crash
160
+ print(f" [{name}] all retries exhausted, using fallback for batch {bi+1}")
161
+ for c in batch:
162
+ cid = c["cluster_id"]
163
+ results[cid] = {
164
+ "cluster_id": cid,
165
+ "label": f"Cluster {cid} ({name} error)",
166
+ "confidence": "Low",
167
+ "reasoning": f"Fallback β€” {name} failed: {str(e)[:80]}",
168
+ }
169
+ # ───────────────────────────────────────────────────────────────
170
+
171
+ # ── inter-batch delay (only for non-cached batches) ─────────────
172
+ # Groq is very fast but strict on RPM; Mistral/OpenAI need breathing room.
173
+ # We sleep INSIDE each thread so they don't interfere with each other.
174
+ BATCH_DELAYS = {"MISTRAL": 12, "OPENAI": 10, "GROQ": 20}
175
  if bi < len(batch_starts) - 1:
176
+ time.sleep(BATCH_DELAYS.get(name, 12))
177
+ # ───────────────────────────────────────────────────────────────
178
+
179
+ return name, results
180
+ # ─────────────────────────────────────────────────────────────────────────
181
+
182
+ # ── 5. PARALLEL DISPATCH ─────────────────────────────────────────────────
183
+ # All 3 threads run simultaneously. Wall time β‰ˆ slowest single member,
184
+ # not sum of all three. Thread count = 3 (one per LLM).
185
+ persona_results = {} # { "MISTRAL": {cid: ...}, ... }
186
+ cache_hits = 0
187
+ cache_misses = 0
188
 
189
+ print("Dispatching 3 LLMs in parallel...")
190
+ with ThreadPoolExecutor(max_workers=3) as executor:
191
+ futures = {executor.submit(run_one_member, m): m["name"] for m in COUNCIL}
192
 
193
+ for future in as_completed(futures):
194
+ member_name = futures[future]
195
+ try:
196
+ name, result_dict = future.result()
197
+ persona_results[name] = result_dict
198
+ print(f"[DONE] {name} finished with {len(result_dict)} cluster labels")
199
+ except Exception as e:
200
+ # should never reach here (worker handles its own errors),
201
+ # but belt-and-suspenders just in case
202
+ print(f"[ERROR] {member_name} thread crashed unexpectedly: {e}")
203
+ persona_results[member_name] = {}
204
  # ─────────────────────────────────────────────────────────────────────────
205
 
206
+ # ── 6. VOTING + ENRICHMENT (unchanged logic) ─────────────────────────────
207
+ LLM_NAMES = ["MISTRAL", "OPENAI", "GROQ"]
208
+
209
  def enrich(cluster):
210
+ cid = cluster["cluster_id"]
211
  raw_votes = [
212
+ str(persona_results.get(name, {}).get(cid, {}).get("label", "")).strip()
213
+ for name in LLM_NAMES
214
  ]
215
  votes = [
216
  v if v and v.lower() not in ("", "none", "null")
217
+ else f"Cluster {cid}"
218
  for v in raw_votes
219
  ]
220
  final = _mode_label(votes)
 
226
  return {
227
  **cluster,
228
  "label": final,
229
+ "llm_vote_1_MISTRAL": votes[0],
230
+ "llm_vote_2_OPENAI": votes[1],
231
+ "llm_vote_3_GROQ": votes[2],
232
+ "confidence_1": persona_results.get("MISTRAL", {}).get(cid, {}).get("confidence", ""),
233
+ "confidence_2": persona_results.get("OPENAI", {}).get(cid, {}).get("confidence", ""),
234
+ "confidence_3": persona_results.get("GROQ", {}).get(cid, {}).get("confidence", ""),
235
+ "reasoning_1": persona_results.get("MISTRAL", {}).get(cid, {}).get("reasoning", ""),
236
+ "reasoning_2": persona_results.get("OPENAI", {}).get(cid, {}).get("reasoning", ""),
237
+ "reasoning_3": persona_results.get("GROQ", {}).get(cid, {}).get("reasoning", ""),
238
  "vote_agreement": agreement,
239
  }
240
 
241
  enriched = list(map(enrich, clusters))
242
  p["summaries"].write_text(json.dumps(enriched, indent=2, ensure_ascii=False))
243
+ # ─────────────────────────────────────────────────────────────────────────
244
 
245
+ # ── 7. AUDIT CSV (unchanged format) ──────────────────────────────────────
246
  rows = []
247
  for c in enriched:
248
  cid = c["cluster_id"]
249
  for li, paper in enumerate(c["papers"]):
250
  rows.append({
251
+ "cluster_id": cid,
252
+ "final_label": c["label"],
253
+ "vote_agreement": c["vote_agreement"],
254
+ "llm1_MISTRAL_label": c["llm_vote_1_MISTRAL"],
255
+ "llm2_OPENAI_label": c["llm_vote_2_OPENAI"],
256
+ "llm3_GROQ_label": c["llm_vote_3_GROQ"],
257
+ "llm1_confidence": c["confidence_1"],
258
+ "llm2_confidence": c["confidence_2"],
259
+ "llm3_confidence": c["confidence_3"],
260
+ "llm1_reasoning": c["reasoning_1"],
261
+ "llm2_reasoning": c["reasoning_2"],
262
+ "llm3_reasoning": c["reasoning_3"],
263
+ "paper_doi": paper.get("doi", ""),
264
+ "paper_title": paper.get("title", ""),
265
+ "paper_year": paper.get("year", ""),
266
+ "paper_journal": paper.get("journal", ""),
267
+ "abstract_preview": paper.get("abstract", "")[:300],
268
+ "combined_preview": paper.get("combined", "")[:200],
269
+ "centroid_cosine_sim": round(float(
270
  c["centroid_sims"][li] if li < len(c["centroid_sims"]) else 0.0), 4),
271
+ "hdbscan_probability": round(float(
272
  c["hdbscan_probs"][li] if li < len(c["hdbscan_probs"]) else 0.0), 4),
273
+ "is_top3_centroid": "YES" if li in c["top3_paper_idx"] else "no",
274
  })
275
 
276
  pd.DataFrame(rows).to_csv(p["audit_csv"], index=False, encoding="utf-8-sig")
277
 
278
  unanimous = sum(1 for c in enriched if c["vote_agreement"] == "unanimous")
279
  majority = sum(1 for c in enriched if c["vote_agreement"] == "majority")
280
+
281
+ # count cache hits by checking what's in cache_dir vs how many API calls were made
282
+ total_batches = len(list(range(0, len(clusters), batch_size))) * 3
283
+ cached_files = len(list(CACHE_DIR.glob("*.json")))
284
+
285
  return json.dumps({
286
+ "clusters_labeled": len(enriched),
287
+ "unanimous": unanimous,
288
+ "majority": majority,
289
+ "split": len(enriched) - unanimous - majority,
290
+ "audit_csv_rows": len(rows),
291
+ "council_members": LLM_NAMES,
292
+ "execution": "parallel (ThreadPoolExecutor, 3 workers)",
293
+ "cache_files_on_disk": cached_files,
294
+ "cache_dir": str(CACHE_DIR),
295
+ "note": (
296
+ "Parallel 3-LLM ensemble done. "
297
+ f"Cache has {cached_files} entries β€” re-runs use these for free. "
298
+ "Audit CSV ready ({} rows).".format(len(rows))
299
+ ),
300
+ })--how this where to paste this but