Wckd314 committed on
Commit
e79361f
Β·
verified Β·
1 Parent(s): f40b3e2

Update utils/llm_client.py

Browse files
Files changed (1) hide show
  1. utils/llm_client.py +608 -603
utils/llm_client.py CHANGED
@@ -1,603 +1,608 @@
1
- """
2
- Pundit Feynman LLM Client β€” 3-Stage Pipeline
3
- Stage 1: Analyze (images β†’ structured JSON analysis)
4
- Stage 2: Design (analysis β†’ implementation plan JSON)
5
- Stage 3: Generate (analysis + design β†’ notebook cells JSON)
6
- """
7
-
8
- import os
9
- import json
10
- import time
11
- import re
12
- import requests
13
- from openai import OpenAI
14
- from dotenv import load_dotenv
15
-
16
- load_dotenv()
17
-
18
# ── Configuration ──────────────────────────────────────────────────────────
# All credentials and endpoints come from the environment (.env loaded above).
API_KEY = os.getenv("NVIDIA_API_KEY", "")  # chat-completions API key
BASE_URL = os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com/v1")
MODEL = os.getenv("LLM_MODEL", "qwen/qwen3.5-397b-a17b")
MAX_IMAGES_PER_REQUEST = int(os.getenv("MAX_IMAGES_PER_REQUEST", "8"))

# OCR Configuration (separate key/endpoint from the chat model)
OCR_API_KEY = os.getenv("NVIDIA_OCR_API_KEY", "")
OCR_API_URL = "https://ai.api.nvidia.com/v1/cv/nvidia/nemoretriever-ocr-v1"

# FLUX.1-schnell Image Generation
FLUX_API_KEY = os.getenv("NVIDIA_FLUX_API_KEY", "")
FLUX_API_URL = "https://ai.api.nvidia.com/v1/genai/black-forest-labs/flux.1-schnell"

# Retry policy for transient LLM API failures; delays are in seconds and
# indexed by attempt number in call_with_retry().
MAX_RETRIES = 3
RETRY_DELAYS = [5, 15, 30]

# Shared OpenAI-compatible client pointed at the NVIDIA endpoint.
client = OpenAI(base_url=BASE_URL, api_key=API_KEY)
36
-
37
-
38
- # ── Prompts ────────────────────────────────────────────────────────────────
39
-
40
- SYSTEM_PROMPT = (
41
- "You are an expert research engineer and educator who converts academic papers into "
42
- "clear, educational, executable Python code. You produce structured JSON output for "
43
- "each stage of the pipeline. When building toy implementations, you create REAL working code "
44
- "(PyTorch, Transformer layers, actual training loops) at reduced scale that "
45
- "runs on CPU. You prioritize faithful replication of the paper's architecture "
46
- "and algorithms while making the code deeply educational with clear explanations, "
47
- "using the Feynman technique to break down complex math into simple analogies, "
48
- "verbose logging, and insightful visualizations."
49
- )
50
-
51
- ANALYSIS_PROMPT = """Analyze this research paper text and return a JSON object with:
52
- {
53
- "title": "exact paper title",
54
- "authors": ["author names"],
55
- "research_field": "e.g. NLP, Computer Vision, RL",
56
- "abstract_summary": "2-3 sentence plain English summary of the paper",
57
- "feynman_analogy": "A brilliant, everyday analogy that maps perfectly to the paper's core key_insight (e.g., comparing attention mechanisms to a cocktail party)",
58
- "feynman_core_concept": "Explain the paper's main idea as if teaching a bright 12-year-old, using the analogy above, in 3-5 sentences",
59
- "key_insight": "the core novel contribution in one sentence",
60
- "algorithms": [
61
- {
62
- "name": "algorithm name",
63
- "purpose": "what it does",
64
- "key_equations": ["important formulas in LaTeX notation"],
65
- "pseudocode_steps": ["step1", "step2"]
66
- }
67
- ],
68
- "architecture": {
69
- "type": "e.g. Transformer, CNN, GAN",
70
- "components": ["list of main components"],
71
- "data_flow": "description of how data flows through the model"
72
- },
73
- "datasets_mentioned": ["dataset names"],
74
- "implementation_requirements": {
75
- "frameworks": ["PyTorch"],
76
- "key_hyperparameters": {"param": "value"},
77
- "estimated_complexity": "low/medium/high for toy version"
78
- }
79
- }
80
-
81
- Return ONLY valid JSON, no markdown, no extra text."""
82
-
83
- DESIGN_PROMPT = """Based on this paper analysis, create a toy implementation design that runs on CPU.
84
- Return a JSON object with:
85
- {
86
- "model_architecture": {
87
- "type": "architecture type",
88
- "embed_dim": 64,
89
- "num_layers": 2,
90
- "num_heads": 4,
91
- "vocab_size": 1000,
92
- "max_seq_len": 64,
93
- "components": [
94
- {
95
- "name": "component name",
96
- "class_name": "PythonClassName",
97
- "description": "what this component does",
98
- "key_params": {"param": "value"}
99
- }
100
- ]
101
- },
102
- "training_config": {
103
- "optimizer": "Adam",
104
- "learning_rate": 0.001,
105
- "num_epochs": 5,
106
- "batch_size": 16,
107
- "loss_function": "CrossEntropyLoss",
108
- "dataset_strategy": "synthetic generation approach"
109
- },
110
- "visualization_plan": [
111
- "loss curve",
112
- "attention heatmap",
113
- "sample predictions"
114
- ],
115
- "estimated_cells": 15,
116
- "code_structure": [
117
- {"section": "imports", "description": "required libraries"},
118
- {"section": "model", "description": "model architecture classes"},
119
- {"section": "data", "description": "synthetic data generation"},
120
- {"section": "training", "description": "training loop"},
121
- {"section": "evaluation", "description": "testing and visualization"}
122
- ]
123
- }
124
-
125
- Return ONLY valid JSON, no markdown, no extra text."""
126
-
127
- GENERATE_PROMPT_TEMPLATE = """You are generating a Jupyter notebook from a paper analysis and implementation design.
128
- Analysis: {analysis}
129
- Design: {design}
130
-
131
- Note: You are a 397B parameter model (Qwen 3.5) with 17B actively used parameters (MoE architecture).
132
- This means you have deep expertise and vast knowledge. Use it to produce genuinely educational content.
133
-
134
- Return a JSON array of notebook cells following this **exact 13-section structure**:
135
-
136
- 1. **Title & Overview** (markdown) β€” Paper title, authors, a one-paragraph summary of the paper.
137
-
138
- 2. **Table of Contents** (markdown) β€” Numbered list of all 13 sections. Each section name should be a clickable anchor link.
139
-
140
- 3. **The Feynman Explanation** (markdown) β€” A step-by-step explanation of the WHOLE paper using the Feynman technique. Break down the core algorithms, math, and architecture into the absolute simplest terms possible. Expand heavily on the `feynman_analogy` and `feynman_core_concept` from the analysis. Use relatable, everyday analogies for each major step so a beginner can intuitively grasp how the system works before seeing the code.
141
-
142
- 4. **Environment Setup** (code) β€” pip installs and imports. Include `torch`, `numpy`, `matplotlib`, and any other needed libraries.
143
-
144
- 5. **Configuration & Hyperparameters** (code) β€” A single config dict or dataclass with all hyperparameters. Add comments explaining each.
145
-
146
- 6. **Data Preparation** (code) β€” Synthetic dataset generation or loading. Must produce realistic dummy data matching the paper's domain.
147
-
148
- 7. **Model Architecture** (code) β€” Full PyTorch model implementation. Use `nn.Module` subclasses with detailed docstrings about each component. Include shape comments.
149
-
150
- 8. **Training Loop** (code) β€” Complete training loop with loss tracking, progress printing, and gradient clipping.
151
-
152
- 9. **Training Execution** (code) β€” Run the training and display results.
153
-
154
- 10. **Evaluation & Metrics** (code) β€” Run inference on test data and compute relevant metrics.
155
-
156
- 11. **Visualizations** (code) β€” Matplotlib charts: loss curves, attention heatmaps or feature maps, sample predictions.
157
-
158
- 12. **Key Takeaways** (markdown) β€” Bullet-point summary of what was learned, what would change at full scale, potential improvements.
159
-
160
- 13. **References** (markdown) β€” Paper citation, related work links, library documentation links.
161
-
162
- Each cell in the JSON array must have:
163
- {{"cell_type": "code" or "markdown", "source": "cell content as a string"}}
164
-
165
- RULES:
166
- - All code must be executable on CPU
167
- - Use educational variable names and heavy commenting
168
- - Include print() statements showing tensor shapes and intermediate results
169
- - Follow the 13-section structure exactly
170
- - Minimum 15 cells total
171
- - The Feynman Explanation should be at least 300 words
172
- - Return ONLY the JSON array, no markdown fences"""
173
-
174
-
175
- # ── OCR extraction (NVIDIA NeMo Retriever OCR v1) ─────────────────────────
176
-
177
def extract_text_from_images(base64_images):
    """Extract text from paper page images via the NVIDIA NeMo Retriever OCR API.

    Pages are OCR'd one request at a time; a page whose request or parse
    fails is skipped and processing continues with the remaining pages.

    Args:
        base64_images: list of base64-encoded JPEG page images.

    Returns:
        str: text of all successfully OCR'd pages, joined by blank lines.

    Raises:
        RuntimeError: if no page yields any text at all.
    """
    all_text = []
    headers = {
        "Authorization": f"Bearer {OCR_API_KEY}",
        "Accept": "application/json",
        "Content-Type": "application/json",
    }

    total = len(base64_images)
    print(f" OCR: Processing {total} pages via NVIDIA NeMo Retriever...")

    for page_idx, img_b64 in enumerate(base64_images):
        print(f" Page {page_idx + 1}/{total}...")

        payload = {
            "input": [
                {
                    "type": "image_url",
                    "url": f"data:image/jpeg;base64,{img_b64}"
                }
            ],
            # Ask the OCR service to merge detections up to paragraph level.
            "merge_levels": ["paragraph"]
        }

        try:
            resp = requests.post(
                OCR_API_URL,
                headers=headers,
                json=payload,
                timeout=60,
            )
            resp.raise_for_status()
            result = resp.json()

            # Extract text from OCR response
            page_text = _parse_ocr_response(result, page_idx + 1)
            if page_text:
                all_text.append(page_text)

        except Exception as e:
            # Best-effort: one bad page must not abort the whole paper.
            print(f" \u26a0 OCR failed for page {page_idx + 1}: {e}")
            # Continue with remaining pages
            continue

    if not all_text:
        raise RuntimeError("OCR failed: No text extracted from any page")

    combined = "\n\n".join(all_text)
    print(f" OCR complete: {len(combined)} chars from {len(all_text)}/{total} pages")
    return combined
231
-
232
-
233
def _parse_ocr_response(response_json, page_num):
    """Flatten an NVIDIA OCR API response into newline-joined plain text.

    Expected response shape:
    {"data": [{"text_detections": [{"text_prediction": {"text": ..., "confidence": ...}}]}]}

    Returns "" (after logging) if the response does not match that shape.
    """
    try:
        # Walk data -> text_detections -> text_prediction lazily.
        predictions = (
            detection.get("text_prediction", {})
            for block in response_json.get("data", [])
            for detection in block.get("text_detections", [])
        )
        accepted = []
        for prediction in predictions:
            snippet = prediction.get("text", "").strip()
            # Keep only non-empty text above the confidence floor.
            if snippet and prediction.get("confidence", 0) > 0.3:
                accepted.append(snippet)
    except Exception as e:
        print(f" \u26a0 Error parsing OCR response for page {page_num}: {e}")
        return ""

    return "\n".join(accepted)
252
-
253
-
254
- # ── LLM Call with Retry ───────────────────────────────────────────────────
255
-
256
def call_with_retry(messages, max_tokens=4096, temperature=0.3, stream=False):
    """Call the LLM chat-completions API with retry on transient errors.

    Args:
        messages: chat messages in OpenAI format.
        max_tokens: completion token budget.
        temperature: sampling temperature.
        stream: when True, return the raw streaming response object;
            otherwise return the completed message content string.

    Returns:
        str | stream object: message content, or the stream when stream=True.

    Raises:
        RuntimeError: after MAX_RETRIES transient failures.
        Exception: non-transient API errors are re-raised immediately.
    """
    last_error = None

    for attempt in range(MAX_RETRIES):
        try:
            kwargs = dict(
                model=MODEL,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                timeout=300,
            )
            if stream:
                kwargs["stream"] = True
                return client.chat.completions.create(**kwargs)
            response = client.chat.completions.create(**kwargs)
            return response.choices[0].message.content

        except Exception as e:
            # Heuristic transient-error detection on the stringified exception.
            error_str = str(e).lower()
            transient = any(
                kw in error_str
                for kw in ["429", "rate", "500", "503", "overloaded", "unavailable"]
            )
            if not transient:
                raise
            last_error = e
            # Fix: only sleep when another attempt remains — the original also
            # slept for RETRY_DELAYS[-1] after the final attempt, delaying the
            # RuntimeError below by 30s for nothing.
            if attempt < MAX_RETRIES - 1:
                wait = RETRY_DELAYS[min(attempt, len(RETRY_DELAYS) - 1)]
                print(f" ⚠ Transient error. Waiting {wait}s before retry {attempt + 1}/{MAX_RETRIES}...")
                time.sleep(wait)

    raise RuntimeError(f"Failed after {MAX_RETRIES} retries. Last error: {last_error}")
287
-
288
-
289
- # ── JSON Parsing ──────────────────────────────────────────────────────────
290
-
291
def parse_llm_json(raw_text, step_name):
    """Parse JSON out of an LLM response, with cleanup and one repair attempt.

    Args:
        raw_text: raw completion text (may be None, or wrapped in ``` fences).
        step_name: pipeline stage label used in log/error messages.

    Returns:
        The parsed JSON value (dict or list), or {} when raw_text is None.

    Raises:
        ValueError: if the text cannot be parsed even after the repair pass.
    """
    if raw_text is None:
        print(f" ⚠ LLM returned None for {step_name}")
        return {}
    text = _strip_json_fences(raw_text)

    # Happy path: the model followed instructions and returned bare JSON.
    try:
        return json.loads(text)
    except json.JSONDecodeError as e:
        print(f" ⚠ JSON parse failed in {step_name}. Attempting repair...")

        # Ask the LLM itself to fix the syntax error; the input is truncated
        # to keep the repair prompt within context limits.
        repair_prompt = (
            f"The following text was supposed to be valid JSON but has a syntax error:\n\n"
            f"{text[:6000]}\n\n"
            f"Error: {e}\n\n"
            f"Return ONLY the corrected valid JSON, nothing else."
        )
        repaired = call_with_retry(
            messages=[
                {"role": "system", "content": "You are a JSON repair tool. Return only valid JSON."},
                {"role": "user", "content": repair_prompt},
            ],
            max_tokens=max(len(text) // 2, 4096),
            temperature=0.1,
        )
        if repaired is None:
            raise ValueError(f"Could not repair JSON from {step_name} — LLM returned None")
        repaired = _strip_json_fences(repaired)

        try:
            return json.loads(repaired)
        except json.JSONDecodeError:
            # Last resort: pull the outermost {...} or [...] span out of the text.
            json_match = re.search(r'[\[{].*[\]}]', repaired, re.DOTALL)
            if json_match:
                try:
                    return json.loads(json_match.group())
                except json.JSONDecodeError:
                    pass  # fall through to the uniform ValueError below
            raise ValueError(f"Could not parse JSON from {step_name} even after repair.")


def _strip_json_fences(text):
    """Strip surrounding whitespace and markdown ``` fences from *text*.

    Robust to a fence with no newline at all — the original `.index("\\n")`
    raised ValueError on single-line fenced input, and the repaired-text
    path could raise IndexError from `split("\\n", 1)[1]`.
    """
    text = text.strip()
    if text.startswith("```"):
        # Drop the opening fence line (may carry a language tag, e.g. ```json).
        text = text.partition("\n")[2]
    if text.endswith("```"):
        text = text[:-3]
    return text.strip()
343
-
344
-
345
- # ── Pipeline Stages ───────────────────────────────────────────────────────
346
-
347
def analyze_paper(raw_text):
    """Stage 1: turn raw extracted paper text into a structured analysis dict."""
    user_content = f"{ANALYSIS_PROMPT}\n\n--- EXTRACTED PAPER TEXT ---\n\n{raw_text}"
    response = call_with_retry(
        [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ],
        max_tokens=6144,
        temperature=0.2,
    )
    return parse_llm_json(response, "paper_analysis")
355
-
356
-
357
def design_implementation(analysis):
    """Stage 2: derive a CPU-scale toy implementation plan from the analysis."""
    user_content = f"{DESIGN_PROMPT}\n\n--- PAPER ANALYSIS ---\n\n{json.dumps(analysis, indent=2)}"
    response = call_with_retry(
        [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ],
        max_tokens=6144,
        temperature=0.2,
    )
    return parse_llm_json(response, "implementation_design")
365
-
366
-
367
def generate_notebook_cells_stream(analysis, design):
    """
    Stage 3: Generate notebook cells from analysis and design.

    Streams the LLM response so the UI can render tokens live.
    Yields ("token", str) for each streamed fragment, then exactly one
    ("cells_final", list) event carrying the parsed notebook cells.
    """
    prompt = GENERATE_PROMPT_TEMPLATE.format(
        analysis=json.dumps(analysis, indent=2),
        design=json.dumps(design, indent=2),
    )
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]

    # Use streaming mode
    stream = call_with_retry(messages, max_tokens=65536, temperature=0.3, stream=True)
    full_response = []

    for chunk in stream:
        # Guard: some stream chunks carry no choices or empty delta content.
        if chunk.choices and chunk.choices[0].delta.content:
            token = chunk.choices[0].delta.content
            full_response.append(token)
            yield ("token", token)

    raw_text = "".join(full_response)
    result = parse_llm_json(raw_text, "notebook_cells")

    # Normalize the parsed result into a list of cell dicts.
    cells = []
    if isinstance(result, dict):
        # Dict response: use its "cells" key if present, otherwise wrap the
        # whole dict as a single markdown cell so nothing is silently lost.
        cells = result.get("cells", [{"cell_type": "markdown", "source": json.dumps(result, indent=2)}])
    elif isinstance(result, list):
        cells = result
    else:
        # Unexpected type: fall back to the raw streamed text as markdown.
        cells = [{"cell_type": "markdown", "source": raw_text}]

    yield ("cells_final", cells)
405
-
406
-
407
- # ── Streaming Pipeline ─────────────────────────────────────────────────────
408
-
409
def run_full_pipeline_stream(raw_text):
    """
    Orchestrates the full 3-stage pipeline (analyze → design → generate).
    Yields SSE-style events for the frontend code viewer; the final cells
    arrive in the single trailing ("cells", list) event.

    Yields tuples of (event_type, data):
        ("text", str) — display text for the code viewer
        ("cells", list) — final cells (only yielded once at end)
        ("analysis", dict) — analysis metadata
        ("error", str) — error message (terminal; no further events follow)
    """
    try:
        # ── Stage 1: Analyze ──
        yield ("text", "\n Analyzing Paper\n")
        yield ("text", " " + "─" * 40 + "\n\n")

        analysis = analyze_paper(raw_text)

        if not analysis:
            yield ("text", " Analysis returned empty. The LLM may have failed.\n\n")
            yield ("error", "Analysis returned empty result")
            return

        title = analysis.get("title", "Unknown Paper")
        field = analysis.get("research_field", "")
        insight = analysis.get("key_insight", "")
        algos = [a.get("name", "") for a in analysis.get("algorithms", [])]
        feynman_analogy = analysis.get("feynman_analogy", "")
        feynman_concept = analysis.get("feynman_core_concept", "")

        # Clean, minimal analysis output
        yield ("text", f" {title}\n")
        yield ("text", f" {field}\n\n")

        # The Feynman Explanation — the star of the show
        if feynman_analogy or feynman_concept:
            yield ("text", " ─── The Feynman Explanation ───\n\n")
            if feynman_analogy:
                yield ("text", f" {feynman_analogy}\n\n")
            if feynman_concept:
                yield ("text", f" {feynman_concept}\n\n")

        if insight:
            yield ("text", f" Key Insight: {insight}\n\n")

        yield ("text", " Analysis complete.\n\n")

        # Structured metadata event for the frontend (short key names).
        yield ("analysis", {
            "title": title,
            "field": field,
            "insight": insight,
            "algorithms": algos,
            "feynman_analogy": feynman_analogy,
        })

        # ── Stage 2: Design ──
        yield ("text", "\n Designing Implementation\n")
        yield ("text", " " + "─" * 40 + "\n\n")

        design = design_implementation(analysis)
        if not design:
            # Degrade gracefully: the display below falls back to defaults.
            design = {}

        arch = design.get("model_architecture", {})
        tc = design.get("training_config", {})
        yield ("text", f" Architecture: {arch.get('type', 'N/A')}\n")
        yield ("text", f" Training: {tc.get('optimizer', 'Adam')}, lr={tc.get('learning_rate', 0.001)}, {tc.get('num_epochs', 10)} epochs\n")
        yield ("text", " Design complete.\n\n")

        # ── Stage 3: Generate (Now with LIVE STREAMING) ──
        yield ("text", "\n Generating Notebook (Live Streaming)\n")
        yield ("text", " " + "─" * 40 + "\n\n")

        cells = []
        for event_type, data in generate_notebook_cells_stream(analysis, design):
            if event_type == "token":
                # Yield raw tokens to the code viewer for "ghost-writing" effect
                yield ("text", data)
            elif event_type == "cells_final":
                cells = data

        code_cells = sum(1 for c in cells if c.get("cell_type") == "code")
        md_cells = sum(1 for c in cells if c.get("cell_type") == "markdown")
        yield ("text", f"\n\n ✅ Generation complete: {len(cells)} cells ({code_cells} code, {md_cells} markdown)\n")
        yield ("text", " Notebook ready for download.\n")

        yield ("cells", cells)

    except Exception as e:
        # Surface any stage failure to the frontend as a terminal error event.
        yield ("error", str(e))
500
-
501
-
502
- # ── Legacy compatibility ───────────────────────────────────────────────────
503
- # Keep old function signatures working for backward compatibility
504
-
505
def extract_methodology(base64_images):
    """Legacy wrapper: extracts text from images."""
    # Kept so older callers importing the pre-pipeline name keep working.
    return extract_text_from_images(base64_images)
508
-
509
-
510
- # ── Visual Illustration (FLUX.1-schnell) ───────────────────────────────────
511
-
512
- # System prompt for Qwen to craft image generation prompts
513
- IMAGE_PROMPT_SYSTEM = """You are a world-class scientific illustrator and prompt engineer.
514
- Your job: given a structured analysis of a research paper, write ONE prompt for an
515
- AI image generator (FLUX) that will produce a clear, beautiful, academic-quality
516
- visual illustration of the paper's CORE CONCEPT.
517
-
518
- Rules:
519
- 1. Focus on the MAIN IDEA β€” the central algorithm, architecture, or mechanism.
520
- 2. Describe the visual layout precisely: shapes, arrows, labels, flow direction.
521
- 3. Use academic illustration style: clean lines, labeled components, white background.
522
- 4. Include spatial relationships: "on the left", "flowing into", "surrounded by".
523
- 5. Mention color coding for different components.
524
- 6. Do NOT include text/equations in the image β€” focus on visual metaphors.
525
- 7. Keep it to ONE paragraph, 80-120 words.
526
- 8. End with style keywords: "scientific diagram, educational poster, vector style,
527
- clean layout, professional, high resolution"
528
-
529
- Return ONLY the prompt text, nothing else."""
530
-
531
def generate_concept_image(analysis):
    """
    Generate a visual illustration of a paper's core concept.
    Step 1: Qwen crafts a detailed, structured prompt from the analysis.
    Step 2: FLUX.1-schnell generates the image.

    Args:
        analysis: paper-analysis dict. Accepts both the full analysis keys
            ("research_field", "key_insight") and the shortened keys emitted
            by run_full_pipeline_stream's "analysis" event ("field", "insight").

    Returns:
        str: base64-encoded image data.

    Raises:
        RuntimeError: if the API key is missing, the crafted prompt is empty,
            the FLUX request fails, or the response carries no image data.
    """
    if not FLUX_API_KEY:
        raise RuntimeError("NVIDIA_FLUX_API_KEY not set")

    # ── Step 1: Qwen → Image Prompt ──
    # Only the fields useful for illustration are forwarded to the prompt model.
    analysis_summary = json.dumps({
        "title": analysis.get("title", ""),
        "research_field": analysis.get("research_field") or analysis.get("field", ""),
        "key_insight": analysis.get("key_insight") or analysis.get("insight", ""),
        "algorithms": analysis.get("algorithms", []),
        "feynman_analogy": analysis.get("feynman_analogy", ""),
        "feynman_core_concept": analysis.get("feynman_core_concept", ""),
    }, indent=2)

    prompt_messages = [
        {"role": "system", "content": IMAGE_PROMPT_SYSTEM},
        {"role": "user", "content": f"Create an image generation prompt for this paper:\n\n{analysis_summary}"},
    ]

    print(" 🎨 Generating image prompt via Qwen...")
    image_prompt = call_with_retry(prompt_messages, max_tokens=300, temperature=0.7)
    if not image_prompt:
        raise RuntimeError("Qwen returned empty image prompt")

    # Add preamble for FLUX to ensure academic quality
    full_prompt = (
        "A detailed, clean scientific illustration for an academic paper. "
        "Style: professional educational diagram, labeled components, "
        "modern flat vector design, white background, high contrast, "
        "color-coded sections, no text. "
        f"{image_prompt.strip()}"
    )
    print(f" 📝 FLUX prompt ({len(full_prompt)} chars): {full_prompt[:100]}...")

    # ── Step 2: FLUX.1-schnell → Image ──
    print(" 🖼️ Calling FLUX.1-schnell...")
    headers = {
        "Authorization": f"Bearer {FLUX_API_KEY}",
        "Content-Type": "application/json",
        "Accept": "application/json",
    }
    payload = {
        "prompt": full_prompt,
        "height": 1024,
        "width": 1024,
        # schnell is a few-step distilled model: 4 steps, guidance disabled.
        "num_inference_steps": 4,
        "guidance_scale": 0.0,
    }

    response = requests.post(FLUX_API_URL, headers=headers, json=payload, timeout=60)

    if response.status_code != 200:
        raise RuntimeError(f"FLUX API error {response.status_code}: {response.text[:200]}")

    result = response.json()
    # FLUX returns {"image": "base64..."} or {"artifacts": [{"base64": "..."}]}
    image_b64 = None
    if "image" in result:
        image_b64 = result["image"]
    elif "artifacts" in result and len(result["artifacts"]) > 0:
        image_b64 = result["artifacts"][0].get("base64", "")

    if not image_b64:
        raise RuntimeError("FLUX returned no image data")

    print(f" ✅ Image generated ({len(image_b64)} chars base64)")
    return image_b64
 
 
 
 
 
 
1
+ """
2
+ Pundit Feynman LLM Client β€” 3-Stage Pipeline
3
+ Stage 1: Analyze (images β†’ structured JSON analysis)
4
+ Stage 2: Design (analysis β†’ implementation plan JSON)
5
+ Stage 3: Generate (analysis + design β†’ notebook cells JSON)
6
+ """
7
+
8
+ import os
9
+ import json
10
+ import time
11
+ import re
12
+ import requests
13
+ from openai import OpenAI
14
+ from dotenv import load_dotenv
15
+
16
+ load_dotenv()
17
+
18
+ # ── Configuration ──────────────────────────────────────────────────────────
19
+ API_KEY = os.getenv("NVIDIA_API_KEY", "")
20
+ BASE_URL = os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com/v1")
21
+ MODEL = os.getenv("LLM_MODEL", "qwen/qwen3.5-397b-a17b")
22
+ MAX_IMAGES_PER_REQUEST = int(os.getenv("MAX_IMAGES_PER_REQUEST", "8"))
23
+
24
+ # OCR Configuration
25
+ OCR_API_KEY = os.getenv("NVIDIA_OCR_API_KEY", "")
26
+ OCR_API_URL = "https://ai.api.nvidia.com/v1/cv/nvidia/nemoretriever-ocr-v1"
27
+
28
+ # FLUX.1-schnell Image Generation
29
+ FLUX_API_KEY = os.getenv("NVIDIA_FLUX_API_KEY", "")
30
+ FLUX_API_URL = "https://ai.api.nvidia.com/v1/genai/black-forest-labs/flux.1-schnell"
31
+
32
+ MAX_RETRIES = 3
33
+ RETRY_DELAYS = [5, 15, 30]
34
+
35
+ client = OpenAI(
36
+ base_url=BASE_URL,
37
+ api_key=API_KEY,
38
+ timeout=600.0, # Explicit default timeout for the client
39
+ )
40
+
41
+
42
+ # ── Prompts ────────────────────────────────────────────────────────────────
43
+
44
+ SYSTEM_PROMPT = (
45
+ "You are an expert research engineer and educator who converts academic papers into "
46
+ "clear, educational, executable Python code. You produce structured JSON output for "
47
+ "each stage of the pipeline. When building toy implementations, you create REAL working code "
48
+ "(PyTorch, Transformer layers, actual training loops) at reduced scale that "
49
+ "runs on CPU. You prioritize faithful replication of the paper's architecture "
50
+ "and algorithms while making the code deeply educational with clear explanations, "
51
+ "using the Feynman technique to break down complex math into simple analogies, "
52
+ "verbose logging, and insightful visualizations."
53
+ )
54
+
55
+ ANALYSIS_PROMPT = """Analyze this research paper text and return a JSON object with:
56
+ {
57
+ "title": "exact paper title",
58
+ "authors": ["author names"],
59
+ "research_field": "e.g. NLP, Computer Vision, RL",
60
+ "abstract_summary": "2-3 sentence plain English summary of the paper",
61
+ "feynman_analogy": "A brilliant, everyday analogy that maps perfectly to the paper's core key_insight (e.g., comparing attention mechanisms to a cocktail party)",
62
+ "feynman_core_concept": "Explain the paper's main idea as if teaching a bright 12-year-old, using the analogy above, in 3-5 sentences",
63
+ "key_insight": "the core novel contribution in one sentence",
64
+ "algorithms": [
65
+ {
66
+ "name": "algorithm name",
67
+ "purpose": "what it does",
68
+ "key_equations": ["important formulas in LaTeX notation"],
69
+ "pseudocode_steps": ["step1", "step2"]
70
+ }
71
+ ],
72
+ "architecture": {
73
+ "type": "e.g. Transformer, CNN, GAN",
74
+ "components": ["list of main components"],
75
+ "data_flow": "description of how data flows through the model"
76
+ },
77
+ "datasets_mentioned": ["dataset names"],
78
+ "implementation_requirements": {
79
+ "frameworks": ["PyTorch"],
80
+ "key_hyperparameters": {"param": "value"},
81
+ "estimated_complexity": "low/medium/high for toy version"
82
+ }
83
+ }
84
+
85
+ Return ONLY valid JSON, no markdown, no extra text."""
86
+
87
+ DESIGN_PROMPT = """Based on this paper analysis, create a toy implementation design that runs on CPU.
88
+ Return a JSON object with:
89
+ {
90
+ "model_architecture": {
91
+ "type": "architecture type",
92
+ "embed_dim": 64,
93
+ "num_layers": 2,
94
+ "num_heads": 4,
95
+ "vocab_size": 1000,
96
+ "max_seq_len": 64,
97
+ "components": [
98
+ {
99
+ "name": "component name",
100
+ "class_name": "PythonClassName",
101
+ "description": "what this component does",
102
+ "key_params": {"param": "value"}
103
+ }
104
+ ]
105
+ },
106
+ "training_config": {
107
+ "optimizer": "Adam",
108
+ "learning_rate": 0.001,
109
+ "num_epochs": 5,
110
+ "batch_size": 16,
111
+ "loss_function": "CrossEntropyLoss",
112
+ "dataset_strategy": "synthetic generation approach"
113
+ },
114
+ "visualization_plan": [
115
+ "loss curve",
116
+ "attention heatmap",
117
+ "sample predictions"
118
+ ],
119
+ "estimated_cells": 15,
120
+ "code_structure": [
121
+ {"section": "imports", "description": "required libraries"},
122
+ {"section": "model", "description": "model architecture classes"},
123
+ {"section": "data", "description": "synthetic data generation"},
124
+ {"section": "training", "description": "training loop"},
125
+ {"section": "evaluation", "description": "testing and visualization"}
126
+ ]
127
+ }
128
+
129
+ Return ONLY valid JSON, no markdown, no extra text."""
130
+
131
+ GENERATE_PROMPT_TEMPLATE = """You are generating a Jupyter notebook from a paper analysis and implementation design.
132
+ Analysis: {analysis}
133
+ Design: {design}
134
+
135
+ Note: You are a 397B parameter model (Qwen 3.5) with 17B actively used parameters (MoE architecture).
136
+ This means you have deep expertise and vast knowledge. Use it to produce genuinely educational content.
137
+
138
+ Return a JSON array of notebook cells following this **exact 13-section structure**:
139
+
140
+ 1. **Title & Overview** (markdown) β€” Paper title, authors, a one-paragraph summary of the paper.
141
+
142
+ 2. **Table of Contents** (markdown) β€” Numbered list of all 13 sections. Each section name should be a clickable anchor link.
143
+
144
+ 3. **The Feynman Explanation** (markdown) β€” A step-by-step explanation of the WHOLE paper using the Feynman technique. Break down the core algorithms, math, and architecture into the absolute simplest terms possible. Expand heavily on the `feynman_analogy` and `feynman_core_concept` from the analysis. Use relatable, everyday analogies for each major step so a beginner can intuitively grasp how the system works before seeing the code.
145
+
146
+ 4. **Environment Setup** (code) β€” pip installs and imports. Include `torch`, `numpy`, `matplotlib`, and any other needed libraries.
147
+
148
+ 5. **Configuration & Hyperparameters** (code) β€” A single config dict or dataclass with all hyperparameters. Add comments explaining each.
149
+
150
+ 6. **Data Preparation** (code) β€” Synthetic dataset generation or loading. Must produce realistic dummy data matching the paper's domain.
151
+
152
+ 7. **Model Architecture** (code) β€” Full PyTorch model implementation. Use `nn.Module` subclasses with detailed docstrings about each component. Include shape comments.
153
+
154
+ 8. **Training Loop** (code) β€” Complete training loop with loss tracking, progress printing, and gradient clipping.
155
+
156
+ 9. **Training Execution** (code) β€” Run the training and display results.
157
+
158
+ 10. **Evaluation & Metrics** (code) β€” Run inference on test data and compute relevant metrics.
159
+
160
+ 11. **Visualizations** (code) β€” Matplotlib charts: loss curves, attention heatmaps or feature maps, sample predictions.
161
+
162
+ 12. **Key Takeaways** (markdown) β€” Bullet-point summary of what was learned, what would change at full scale, potential improvements.
163
+
164
+ 13. **References** (markdown) β€” Paper citation, related work links, library documentation links.
165
+
166
+ Each cell in the JSON array must have:
167
+ {{"cell_type": "code" or "markdown", "source": "cell content as a string"}}
168
+
169
+ RULES:
170
+ - All code must be executable on CPU
171
+ - Use educational variable names and heavy commenting
172
+ - Include print() statements showing tensor shapes and intermediate results
173
+ - Follow the 13-section structure exactly
174
+ - Minimum 15 cells total
175
+ - The Feynman Explanation should be at least 300 words
176
+ - Return ONLY the JSON array, no markdown fences"""
177
+
178
+
179
+ # ── OCR extraction (NVIDIA NeMo Retriever OCR v1) ─────────────────────────
180
+
181
def extract_text_from_images(base64_images):
    """Extract text from paper page images via the NVIDIA NeMo Retriever OCR API.

    Each page is sent as its own request; a page whose request or parsing
    fails is logged and skipped so the remaining pages can still be processed.
    (The previous docstring claimed a batched request with per-page fallback;
    the implementation has always been strictly page-by-page.)

    Args:
        base64_images: list of base64-encoded JPEG page images.

    Returns:
        str: OCR text of all successfully processed pages, joined by blank lines.

    Raises:
        RuntimeError: if no page yields any text at all.
    """
    all_text = []
    headers = {
        "Authorization": f"Bearer {OCR_API_KEY}",
        "Accept": "application/json",
        "Content-Type": "application/json",
    }

    total = len(base64_images)
    print(f" OCR: Processing {total} pages via NVIDIA NeMo Retriever...")

    for page_idx, img_b64 in enumerate(base64_images):
        print(f" Page {page_idx + 1}/{total}...")

        payload = {
            "input": [
                {
                    "type": "image_url",
                    "url": f"data:image/jpeg;base64,{img_b64}"
                }
            ],
            # Merge raw word/line detections into paragraph-level text blocks.
            "merge_levels": ["paragraph"]
        }

        try:
            resp = requests.post(
                OCR_API_URL,
                headers=headers,
                json=payload,
                timeout=60,
            )
            resp.raise_for_status()
            result = resp.json()

            # Extract clean text from the OCR response for this page.
            page_text = _parse_ocr_response(result, page_idx + 1)
            if page_text:
                all_text.append(page_text)

        except Exception as e:
            # Best-effort: log and move on so one bad page doesn't abort the run.
            print(f" \u26a0 OCR failed for page {page_idx + 1}: {e}")

    if not all_text:
        raise RuntimeError("OCR failed: No text extracted from any page")

    combined = "\n\n".join(all_text)
    print(f" OCR complete: {len(combined)} chars from {len(all_text)}/{total} pages")
    return combined
235
+
236
+
237
+ def _parse_ocr_response(response_json, page_num):
238
+ """Parse the NVIDIA OCR API response into clean text.
239
+ Response format: {"data": [{"text_detections": [{"text_prediction": {"text": ..., "confidence": ...}}]}]}
240
+ """
241
+ texts = []
242
+ try:
243
+ for item in response_json.get("data", []):
244
+ for detection in item.get("text_detections", []):
245
+ pred = detection.get("text_prediction", {})
246
+ text = pred.get("text", "").strip()
247
+ confidence = pred.get("confidence", 0)
248
+ # Only include text with reasonable confidence
249
+ if text and confidence > 0.3:
250
+ texts.append(text)
251
+ except Exception as e:
252
+ print(f" \u26a0 Error parsing OCR response for page {page_num}: {e}")
253
+ return ""
254
+
255
+ return "\n".join(texts)
256
+
257
+
258
+ # ── LLM Call with Retry ───────────────────────────────────────────────────
259
+
260
def call_with_retry(messages, max_tokens=4096, temperature=0.3, stream=False):
    """Invoke the chat-completions endpoint, retrying transient failures.

    Args:
        messages: chat messages in OpenAI format.
        max_tokens: completion token budget.
        temperature: sampling temperature.
        stream: if True, return the raw streaming iterator; otherwise return
            the first choice's message content as a string.

    Raises:
        RuntimeError: after MAX_RETRIES consecutive transient failures.
        Exception: immediately, for non-transient API errors.
    """
    # Substrings that identify a retryable (rate-limit / server / timeout) error.
    transient_markers = (
        "429", "rate", "500", "503", "overloaded", "unavailable",
        "timeout", "timed out",
    )
    last_error = None

    for attempt in range(MAX_RETRIES):
        try:
            if stream:
                return client.chat.completions.create(
                    model=MODEL,
                    messages=messages,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    timeout=300,
                    stream=True,
                )
            completion = client.chat.completions.create(
                model=MODEL,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                timeout=300,
            )
            return completion.choices[0].message.content

        except Exception as err:
            if not any(marker in str(err).lower() for marker in transient_markers):
                raise  # non-transient: surface immediately
            last_error = err
            # Back off using the configured schedule, clamped to its last entry.
            delay = RETRY_DELAYS[min(attempt, len(RETRY_DELAYS) - 1)]
            print(f" ⚠ Transient error: {err}. Waiting {delay}s before retry {attempt + 1}/{MAX_RETRIES}...")
            time.sleep(delay)

    raise RuntimeError(f"Failed after {MAX_RETRIES} retries. Last error: {last_error}")
292
+
293
+
294
+ # ── JSON Parsing ──────────────────────────────────────────────────────────
295
+
296
def _strip_markdown_fences(text):
    """Remove surrounding markdown code fences (``` or ```json) from *text*.

    Handles both multi-line fenced blocks and the degenerate single-line case
    (e.g. '```json{...}```'), which the previous text.index("\\n") /
    split("\\n", 1)[1] approach crashed on.
    """
    text = text.strip()
    text = re.sub(r"^```[\w-]*[ \t]*\n?", "", text)
    text = re.sub(r"\n?```$", "", text)
    return text.strip()


def parse_llm_json(raw_text, step_name):
    """Parse JSON from an LLM response, with cleanup and one repair attempt.

    Args:
        raw_text: the raw LLM output (may be None, fenced, or malformed).
        step_name: pipeline stage name, used in log/error messages.

    Returns:
        The parsed JSON value, or {} when raw_text is None.

    Raises:
        ValueError: if the text cannot be parsed even after an LLM repair pass.
    """
    if raw_text is None:
        print(f" ⚠ LLM returned None for {step_name}")
        return {}

    text = _strip_markdown_fences(raw_text)

    # Try direct parse first; only fall back to the repair LLM on failure.
    try:
        return json.loads(text)
    except json.JSONDecodeError as e:
        print(f" ⚠ JSON parse failed in {step_name}. Attempting repair...")
        parse_error = e

    # Attempt auto-repair via LLM
    repair_prompt = (
        f"The following text was supposed to be valid JSON but has a syntax error:\n\n"
        f"{text[:6000]}\n\n"
        f"Error: {parse_error}\n\n"
        f"Return ONLY the corrected valid JSON, nothing else."
    )
    repaired = call_with_retry(
        messages=[
            {"role": "system", "content": "You are a JSON repair tool. Return only valid JSON."},
            {"role": "user", "content": repair_prompt},
        ],
        max_tokens=max(len(text) // 2, 4096),
        temperature=0.1,
    )
    if repaired is None:
        raise ValueError(f"Could not repair JSON from {step_name} — LLM returned None")
    repaired = _strip_markdown_fences(repaired)

    try:
        return json.loads(repaired)
    except json.JSONDecodeError:
        # Last resort: extract the outermost JSON-looking span from the text.
        json_match = re.search(r'[\[{].*[\]}]', repaired, re.DOTALL)
        if json_match:
            try:
                return json.loads(json_match.group())
            except json.JSONDecodeError:
                pass  # fall through to the uniform ValueError below
        raise ValueError(f"Could not parse JSON from {step_name} even after repair.")
348
+
349
+
350
+ # ── Pipeline Stages ───────────────────────────────────────────────────────
351
+
352
def analyze_paper(raw_text):
    """Stage 1: turn OCR-extracted paper text into a structured analysis dict."""
    user_content = f"{ANALYSIS_PROMPT}\n\n--- EXTRACTED PAPER TEXT ---\n\n{raw_text}"
    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]
    # Low temperature: analysis should be deterministic and faithful.
    raw_response = call_with_retry(conversation, max_tokens=6144, temperature=0.2)
    return parse_llm_json(raw_response, "paper_analysis")
360
+
361
+
362
def design_implementation(analysis):
    """Stage 2: derive an implementation-design dict from the Stage-1 analysis."""
    user_content = (
        f"{DESIGN_PROMPT}\n\n--- PAPER ANALYSIS ---\n\n{json.dumps(analysis, indent=2)}"
    )
    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]
    # Low temperature keeps the design tightly grounded in the analysis.
    raw_response = call_with_retry(conversation, max_tokens=6144, temperature=0.2)
    return parse_llm_json(raw_response, "implementation_design")
370
+
371
+
372
def generate_notebook_cells_stream(analysis, design):
    """
    Stage 3: generate notebook cells from the analysis and design.

    A generator that yields ("token", str) for each streamed LLM token (for
    live display in the UI), then a final ("cells_final", list) event with
    the parsed list of notebook cells.
    """
    user_prompt = GENERATE_PROMPT_TEMPLATE.format(
        analysis=json.dumps(analysis, indent=2),
        design=json.dumps(design, indent=2),
    )
    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]

    # Stream the completion, forwarding tokens as they arrive while
    # accumulating the full response for parsing afterwards.
    pieces = []
    for chunk in call_with_retry(conversation, max_tokens=65536, temperature=0.3, stream=True):
        delta = chunk.choices[0].delta.content if chunk.choices else None
        if delta:
            pieces.append(delta)
            yield ("token", delta)

    raw_text = "".join(pieces)
    parsed = parse_llm_json(raw_text, "notebook_cells")

    # Normalize whatever came back into a list of cell dicts.
    if isinstance(parsed, list):
        cells = parsed
    elif isinstance(parsed, dict):
        cells = parsed.get(
            "cells",
            [{"cell_type": "markdown", "source": json.dumps(parsed, indent=2)}],
        )
    else:
        cells = [{"cell_type": "markdown", "source": raw_text}]

    yield ("cells_final", cells)
410
+
411
+
412
+ # ── Streaming Pipeline ─────────────────────────────────────────────────────
413
+
414
def run_full_pipeline_stream(raw_text):
    """
    Orchestrates the full 3-stage pipeline.
    Yields SSE-formatted text events for the frontend code viewer.
    Returns final cells via the 'cells' key in the last event.

    Yields tuples of (event_type, data):
        ("text", str) — display text for the code viewer
        ("cells", list) — final cells (only yielded once at end)
        ("analysis", dict) — analysis metadata
        ("error", str) — error message

    Note: exceptions are never raised to the caller; any failure is reported
    as a final ("error", message) event so the SSE stream ends cleanly.
    """
    try:
        # ── Stage 1: Analyze ──
        yield ("text", "\n Analyzing Paper\n")
        yield ("text", " " + "─" * 40 + "\n\n")

        analysis = analyze_paper(raw_text)

        # Empty analysis means the LLM call or JSON parse produced nothing
        # usable; abort the pipeline rather than designing from nothing.
        if not analysis:
            yield ("text", " Analysis returned empty. The LLM may have failed.\n\n")
            yield ("error", "Analysis returned empty result")
            return

        # Pull out display fields; every key is optional in the LLM output.
        title = analysis.get("title", "Unknown Paper")
        field = analysis.get("research_field", "")
        insight = analysis.get("key_insight", "")
        algos = [a.get("name", "") for a in analysis.get("algorithms", [])]
        feynman_analogy = analysis.get("feynman_analogy", "")
        feynman_concept = analysis.get("feynman_core_concept", "")

        # Clean, minimal analysis output
        yield ("text", f" {title}\n")
        yield ("text", f" {field}\n\n")

        # The Feynman Explanation — the star of the show
        if feynman_analogy or feynman_concept:
            yield ("text", " ─── The Feynman Explanation ───\n\n")
            if feynman_analogy:
                yield ("text", f" {feynman_analogy}\n\n")
            if feynman_concept:
                yield ("text", f" {feynman_concept}\n\n")

        if insight:
            yield ("text", f" Key Insight: {insight}\n\n")

        yield ("text", " Analysis complete.\n\n")

        # Structured metadata event for the frontend (separate from the
        # human-readable "text" events above).
        yield ("analysis", {
            "title": title,
            "field": field,
            "insight": insight,
            "algorithms": algos,
            "feynman_analogy": feynman_analogy,
        })

        # ── Stage 2: Design ──
        yield ("text", "\n Designing Implementation\n")
        yield ("text", " " + "─" * 40 + "\n\n")

        design = design_implementation(analysis)
        if not design:
            design = {}

        arch = design.get("model_architecture", {})
        tc = design.get("training_config", {})
        yield ("text", f" Architecture: {arch.get('type', 'N/A')}\n")
        yield ("text", f" Training: {tc.get('optimizer', 'Adam')}, lr={tc.get('learning_rate', 0.001)}, {tc.get('num_epochs', 10)} epochs\n")
        yield ("text", " Design complete.\n\n")

        # ── Stage 3: Generate (Now with LIVE STREAMING) ──
        yield ("text", "\n Generating Notebook (Live Streaming)\n")
        yield ("text", " " + "─" * 40 + "\n\n")

        # Consume the stage-3 generator: forward each token as display text,
        # and capture the final parsed cell list from the "cells_final" event.
        cells = []
        for event_type, data in generate_notebook_cells_stream(analysis, design):
            if event_type == "token":
                # Yield raw tokens to the code viewer for "ghost-writing" effect
                yield ("text", data)
            elif event_type == "cells_final":
                cells = data

        code_cells = sum(1 for c in cells if c.get("cell_type") == "code")
        md_cells = sum(1 for c in cells if c.get("cell_type") == "markdown")
        yield ("text", f"\n\n ✅ Generation complete: {len(cells)} cells ({code_cells} code, {md_cells} markdown)\n")
        yield ("text", " Notebook ready for download.\n")

        yield ("cells", cells)

    except Exception as e:
        # Surface any pipeline failure as an SSE error event, not a raise.
        yield ("error", str(e))
505
+
506
+
507
+ # ── Legacy compatibility ───────────────────────────────────────────────────
508
+ # Keep old function signatures working for backward compatibility
509
+
510
def extract_methodology(base64_images):
    """Deprecated alias kept for backward compatibility.

    Delegates directly to extract_text_from_images().
    """
    return extract_text_from_images(base64_images)
513
+
514
+
515
+ # ── Visual Illustration (FLUX.1-schnell) ───────────────────────────────────
516
+
517
# System prompt for Qwen to craft image generation prompts.
# The rules below steer the model toward ONE short (80-120 word), text-free,
# academic-diagram-style FLUX prompt focused on the paper's core mechanism.
IMAGE_PROMPT_SYSTEM = """You are a world-class scientific illustrator and prompt engineer.
Your job: given a structured analysis of a research paper, write ONE prompt for an
AI image generator (FLUX) that will produce a clear, beautiful, academic-quality
visual illustration of the paper's CORE CONCEPT.

Rules:
1. Focus on the MAIN IDEA — the central algorithm, architecture, or mechanism.
2. Describe the visual layout precisely: shapes, arrows, labels, flow direction.
3. Use academic illustration style: clean lines, labeled components, white background.
4. Include spatial relationships: "on the left", "flowing into", "surrounded by".
5. Mention color coding for different components.
6. Do NOT include text/equations in the image — focus on visual metaphors.
7. Keep it to ONE paragraph, 80-120 words.
8. End with style keywords: "scientific diagram, educational poster, vector style,
clean layout, professional, high resolution"

Return ONLY the prompt text, nothing else."""
535
+
536
def generate_concept_image(analysis):
    """
    Generate a visual illustration of a paper's core concept.

    Step 1: Qwen crafts a detailed, structured image prompt from the analysis.
    Step 2: FLUX.1-schnell generates the image from that prompt.

    Args:
        analysis: Stage-1 paper-analysis dict (title, field, insight, ...).

    Returns:
        str: base64-encoded image data. (The previous docstring claimed
        "or None on failure", but every failure path raises instead.)

    Raises:
        RuntimeError: if the FLUX API key is missing, the crafted prompt is
            empty, the FLUX API returns a non-200 status, or the response
            contains no image data.
    """
    if not FLUX_API_KEY:
        raise RuntimeError("NVIDIA_FLUX_API_KEY not set")

    # ── Step 1: Qwen → Image Prompt ──
    # Only the illustration-relevant fields are forwarded; the short key
    # names ("field", "insight") are legacy fallbacks from older payloads.
    analysis_summary = json.dumps({
        "title": analysis.get("title", ""),
        "research_field": analysis.get("research_field") or analysis.get("field", ""),
        "key_insight": analysis.get("key_insight") or analysis.get("insight", ""),
        "algorithms": analysis.get("algorithms", []),
        "feynman_analogy": analysis.get("feynman_analogy", ""),
        "feynman_core_concept": analysis.get("feynman_core_concept", ""),
    }, indent=2)

    prompt_messages = [
        {"role": "system", "content": IMAGE_PROMPT_SYSTEM},
        {"role": "user", "content": f"Create an image generation prompt for this paper:\n\n{analysis_summary}"},
    ]

    print(" 🎨 Generating image prompt via Qwen...")
    image_prompt = call_with_retry(prompt_messages, max_tokens=300, temperature=0.7)
    if not image_prompt:
        raise RuntimeError("Qwen returned empty image prompt")

    # Add preamble for FLUX to ensure academic quality
    full_prompt = (
        "A detailed, clean scientific illustration for an academic paper. "
        "Style: professional educational diagram, labeled components, "
        "modern flat vector design, white background, high contrast, "
        "color-coded sections, no text. "
        f"{image_prompt.strip()}"
    )
    print(f" 📝 FLUX prompt ({len(full_prompt)} chars): {full_prompt[:100]}...")

    # ── Step 2: FLUX.1-schnell → Image ──
    print(" 🖼️ Calling FLUX.1-schnell...")
    headers = {
        "Authorization": f"Bearer {FLUX_API_KEY}",
        "Content-Type": "application/json",
        "Accept": "application/json",
    }
    payload = {
        "prompt": full_prompt,
        "height": 1024,
        "width": 1024,
        # schnell is distilled for few-step, guidance-free sampling.
        "num_inference_steps": 4,
        "guidance_scale": 0.0,
    }

    response = requests.post(FLUX_API_URL, headers=headers, json=payload, timeout=60)

    if response.status_code != 200:
        raise RuntimeError(f"FLUX API error {response.status_code}: {response.text[:200]}")

    result = response.json()
    # FLUX returns {"image": "base64..."} or {"artifacts": [{"base64": "..."}]};
    # check both shapes, tolerating an empty "image" value.
    image_b64 = result.get("image")
    if not image_b64:
        artifacts = result.get("artifacts") or []
        if artifacts:
            image_b64 = artifacts[0].get("base64", "")

    if not image_b64:
        raise RuntimeError("FLUX returned no image data")

    print(f" ✅ Image generated ({len(image_b64)} chars base64)")
    return image_b64