Ali Hashhash committed on
Commit
1d88d91
·
1 Parent(s): e0ffc4f

feat: add note_generator module to handle automated summarization tasks

Browse files
Files changed (1) hide show
  1. src/summarization/note_generator.py +384 -17
src/summarization/note_generator.py CHANGED
@@ -1,6 +1,8 @@
1
  import json
2
  import os
3
- from typing import Dict, Optional
 
 
4
 
5
  from groq import Groq
6
  from pydantic import ValidationError
@@ -13,7 +15,23 @@ logger = setup_logger(__name__)
13
 
14
 
15
  # ─────────────────────────────────────────────────────────────────────────────
16
- # PROMPT TEMPLATES
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  # ─────────────────────────────────────────────────────────────────────────────
18
 
19
  _SUMMARY_SYSTEM = """
@@ -76,6 +94,102 @@ Return ONLY the exact JSON structure requested.
76
  """.strip()
77
 
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  # ─────────────────────────────────────────────────────────────────────────────
80
  # LANGUAGE LABELS (simplified)
81
  # ─────────────────────────────────────────────────────────────────────────────
@@ -105,23 +219,122 @@ def _labels(language: str) -> dict:
105
  return _LABELS.get(language, _LABELS["English"])
106
 
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  # ─────────────────────────────────────────────────────────────────────────────
109
  # NOTE GENERATOR
110
  # ─────────────────────────────────────────────────────────────────────────────
111
 
112
  class NoteGenerator:
113
- """Generates structured study notes using Groq (Llama-3.3-70b-versatile)."""
 
 
 
 
 
 
 
 
114
 
115
  def __init__(self):
116
  self.api_key = os.environ.get("GROQ_API_KEY", "").strip()
117
  self.client = Groq(api_key=self.api_key) if self.api_key else None
118
- self.model_id = "llama-3.3-70b-versatile"
119
- logger.info(f"πŸš€ NoteGenerator v4.0 initialized β€” model: {self.model_id}")
 
 
 
 
 
 
 
 
 
120
 
121
- def _chat(self, system: str, user: str, max_tokens: int = 4096) -> Optional[str]:
 
 
 
 
 
 
 
 
122
  try:
123
  response = self.client.chat.completions.create(
124
- model=self.model_id,
125
  max_tokens=max_tokens,
126
  temperature=0.3,
127
  response_format={"type": "json_object"},
@@ -132,9 +345,11 @@ class NoteGenerator:
132
  )
133
  return response.choices[0].message.content
134
  except Exception as e:
135
- logger.error(f"❌ Groq API call failed: {e}")
136
  return None
137
 
 
 
138
  def _get_error_json(self, error_msg: str) -> Dict:
139
  return {
140
  "title": "Error in Generation",
@@ -145,29 +360,181 @@ class NoteGenerator:
145
  "topics": [],
146
  }
147
 
148
- def generateSummary(self, transcript_text: str, video_title: str) -> Dict:
149
- """Generate structured JSON summary from transcript."""
150
- if not self.client:
151
- return self._get_error_json("Groq API Key missing.")
 
152
 
153
- logger.info(f"πŸ“ Summary generation started via {self.model_id}")
154
  user_prompt = _SUMMARY_USER.format(
155
  video_title=video_title,
156
- transcript=transcript_text[:30000],
157
  )
158
 
159
  raw = self._chat(_SUMMARY_SYSTEM, user_prompt, max_tokens=4096)
160
  if raw is None:
161
- return self._get_error_json("Groq API call failed.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  try:
164
- data = json.loads(raw)
165
  validated = SummarySchema(**data)
166
  return validated.model_dump()
167
  except (json.JSONDecodeError, ValidationError) as e:
168
- logger.error(f"❌ Schema validation failed: {e}")
169
  return self._get_error_json(f"Validation Error: {str(e)}")
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  def format_notes_to_markdown(self, json_notes: Dict) -> str:
172
  """Convert JSON notes to clean Markdown β€” Summary β†’ Timeline β†’ Conclusion."""
173
  lang = json_notes.get("detected_language", "English")
 
1
  import json
2
  import os
3
+ import re
4
+ import time
5
+ from typing import Dict, List, Optional
6
 
7
  from groq import Groq
8
  from pydantic import ValidationError
 
15
 
16
 
17
  # ─────────────────────────────────────────────────────────────────────────────
18
+ # CONFIGURATION
19
+ # ─────────────────────────────────────────────────────────────────────────────
20
+
21
+ # Token threshold: below this, a single API call is used.
22
+ _SINGLE_PASS_TOKEN_LIMIT = 8_000
23
+
24
+ # Target chunk size for MAP phase (tokens). Leaves room for prompt + response
25
+ # within the 12K TPM free-tier limit.
26
+ _CHUNK_TARGET_TOKENS = 6_000
27
+
28
+ # Models
29
+ _MODEL_PRIMARY = "llama-3.3-70b-versatile" # REDUCE phase + single-pass
30
+ _MODEL_MAP = "llama-3.1-8b-instant" # MAP phase (fast, cheap)
31
+
32
+
33
+ # ─────────────────────────────────────────────────────────────────────────────
34
+ # PROMPT TEMPLATES β€” SINGLE-PASS (unchanged)
35
  # ─────────────────────────────────────────────────────────────────────────────
36
 
37
  _SUMMARY_SYSTEM = """
 
94
  """.strip()
95
 
96
 
97
+ # ─────────────────────────────────────────────────────────────────────────────
98
+ # PROMPT TEMPLATES β€” MAP PHASE
99
+ # ─────────────────────────────────────────────────────────────────────────────
100
+
101
+ _MAP_SYSTEM = """
102
+ You are an expert educational content analyst.
103
+ You will receive ONE CHUNK of a longer video transcript.
104
+ Extract the key information from this chunk ONLY.
105
+
106
+ LANGUAGE RULE β€” CRITICAL:
107
+ - Detect the primary language of the text.
108
+ - Write ALL content fields in that SAME detected language.
109
+ - Only "detected_language" is stated in English.
110
+
111
+ Return a JSON object with this EXACT structure:
112
+ {
113
+ "detected_language": "English (or Arabic, etc.)",
114
+ "chunk_summary": "Concise summary of this chunk (3-5 sentences)",
115
+ "key_points": [
116
+ {
117
+ "title": "Short title for this point",
118
+ "detail": "1-2 sentence explanation",
119
+ "insight": "Key takeaway"
120
+ }
121
+ ],
122
+ "topics": ["Topic1", "Topic2"]
123
+ }
124
+
125
+ RULES:
126
+ - Extract 2-4 key points from this chunk.
127
+ - Topics should be specific (e.g. "Python", "Neural Networks"), not generic.
128
+ - OUTPUT: Return ONLY a valid JSON object. No markdown fences, no extra text.
129
+ """.strip()
130
+
131
+ _MAP_USER = """
132
+ Video Title: {video_title}
133
+ Chunk {chunk_index} of {total_chunks}:
134
+
135
+ {chunk_text}
136
+
137
+ Extract the key information from this chunk. Return ONLY the JSON.
138
+ """.strip()
139
+
140
+
141
+ # ─────────────────────────────────────────────────────────────────────────────
142
+ # PROMPT TEMPLATES β€” REDUCE PHASE
143
+ # ─────────────────────────────────────────────────────────────────────────────
144
+
145
+ _REDUCE_SYSTEM = """
146
+ You are an expert educational content analyst and structured note-taking specialist.
147
+ You will receive INTERMEDIATE SUMMARIES from multiple chunks of a single video transcript.
148
+ Your job is to MERGE them into ONE final, cohesive, structured summary.
149
+
150
+ LANGUAGE RULE β€” CRITICAL, NEVER VIOLATE:
151
+ - Use the detected language from the intermediate summaries.
152
+ - Every content field MUST be in that SAME language.
153
+ - Only "detected_language" is stated in English.
154
+
155
+ TIMELINE RULES β€” STRICTLY ENFORCED:
156
+ - Merge the chunk summaries into 3-7 chronological segments.
157
+ - Each segment MUST cover a distinct phase or theme; do NOT repeat topics.
158
+ - Segments must follow the natural progression of the video.
159
+ - Each segment must include: title, summary, key_insight, why_it_matters.
160
+
161
+ CRITICAL: RETURN A JSON OBJECT EXACTLY MATCHING THIS STRUCTURE.
162
+ {
163
+ "title": "Inferred video title in transcript language",
164
+ "detected_language": "English (or Arabic, etc.)",
165
+ "summary": "Concise overall summary (3-5 sentences)",
166
+ "segments": [
167
+ {
168
+ "title": "Segment title",
169
+ "summary": "What this section covers (2-3 sentences)",
170
+ "key_insight": "Most important point from this section",
171
+ "why_it_matters": "Why this is valuable (1-2 sentences)"
172
+ }
173
+ ],
174
+ "conclusion": "Final overall takeaway / closing conclusion",
175
+ "topics": ["Topic1", "Topic2", "Topic3"]
176
+ }
177
+
178
+ OUTPUT: Return ONLY a valid JSON object. No markdown fences, no extra text.
179
+ """.strip()
180
+
181
+ _REDUCE_USER = """
182
+ Video Title: {video_title}
183
+
184
+ The following are intermediate summaries extracted from {total_chunks} consecutive chunks
185
+ of the video transcript. Merge them into ONE cohesive final summary.
186
+
187
+ {merged_summaries}
188
+
189
+ Merge into 3-7 chronological segments. Return ONLY the final JSON structure.
190
+ """.strip()
191
+
192
+
193
  # ─────────────────────────────────────────────────────────────────────────────
194
  # LANGUAGE LABELS (simplified)
195
  # ─────────────────────────────────────────────────────────────────────────────
 
219
  return _LABELS.get(language, _LABELS["English"])
220
 
221
 
222
+ # ─────────────────────────────────────────────────────────────────────────────
223
+ # TOKEN UTILITIES
224
+ # ─────────────────────────────────────────────────────────────────────────────
225
+
226
+ def _estimate_tokens(text: str) -> int:
227
+ """
228
+ Lightweight token estimation using a word-count heuristic.
229
+
230
+ LLM tokenizers typically produce ~1.3 tokens per whitespace-delimited word
231
+ for English. Arabic and mixed-script text can be slightly higher, but 1.3
232
+ is a safe, conservative multiplier.
233
+ """
234
+ word_count = len(text.split())
235
+ return int(word_count * 1.3)
236
+
237
+
238
def _split_oversized_sentence(sentence: str, sentence_tokens: int, target_tokens: int) -> List[str]:
    """Cut one sentence that alone exceeds ``target_tokens`` into word groups."""
    words = sentence.split()
    # Spread the sentence-level estimate evenly over its words. Estimating
    # each word on its own would truncate 1.3 down to 1 token per word and
    # let every piece overshoot the target by roughly 30%.
    tokens_per_word = sentence_tokens / len(words)
    words_per_chunk = max(1, int(target_tokens / tokens_per_word))
    return [
        " ".join(words[start:start + words_per_chunk])
        for start in range(0, len(words), words_per_chunk)
    ]


def _split_into_chunks(text: str, target_tokens: int = _CHUNK_TARGET_TOKENS) -> List[str]:
    """
    Split text into chunks of approximately ``target_tokens`` tokens each.

    The text is first cut into sentences: at whitespace that follows ``.``,
    ``!`` or ``?``, and at newlines. Sentences are then packed greedily into
    chunks, starting a new chunk whenever adding the next sentence would
    push the running estimate past ``target_tokens``. A single sentence
    whose own estimate exceeds the target (typical for unpunctuated
    auto-generated captions) is split on word boundaries instead.

    Args:
        text: Full transcript text.
        target_tokens: Approximate per-chunk budget, measured with
            ``_estimate_tokens``.

    Returns:
        The chunks in original order. Empty or whitespace-only input
        yields an empty list. A single whitespace-free "word" larger than
        the target cannot be split further and is returned as one chunk.
    """
    pieces = re.split(r'(?<=[.!?])\s+|\n+', text)
    sentences = [piece.strip() for piece in pieces if piece.strip()]

    chunks: List[str] = []
    pending: List[str] = []
    pending_tokens = 0

    for sentence in sentences:
        sentence_tokens = _estimate_tokens(sentence)

        if sentence_tokens > target_tokens:
            # Flush what has been collected so far to keep the order intact,
            # then emit the oversized sentence as its own word-level chunks.
            if pending:
                chunks.append(" ".join(pending))
                pending, pending_tokens = [], 0
            chunks.extend(
                _split_oversized_sentence(sentence, sentence_tokens, target_tokens)
            )
            continue

        if pending and pending_tokens + sentence_tokens > target_tokens:
            chunks.append(" ".join(pending))
            pending, pending_tokens = [sentence], sentence_tokens
        else:
            pending.append(sentence)
            pending_tokens += sentence_tokens

    # Whatever is left after the last sentence forms the final chunk.
    if pending:
        chunks.append(" ".join(pending))

    return chunks
294
+
295
+
296
  # ─────────────────────────────────────────────────────────────────────────────
297
  # NOTE GENERATOR
298
  # ─────────────────────────────────────────────────────────────────────────────
299
 
300
  class NoteGenerator:
301
+ """
302
+ Generates structured study notes using Groq.
303
+
304
+ Automatically selects between:
305
+ - **Single-pass**: for short transcripts (< 8K tokens)
306
+ - **Map-Reduce**: for long transcripts (β‰₯ 8K tokens), splitting into
307
+ chunks, summarizing each with a fast model, then merging with the
308
+ primary model.
309
+ """
310
 
311
  def __init__(self):
312
  self.api_key = os.environ.get("GROQ_API_KEY", "").strip()
313
  self.client = Groq(api_key=self.api_key) if self.api_key else None
314
+ self.model_primary = _MODEL_PRIMARY
315
+ self.model_map = _MODEL_MAP
316
+ self.chunk_delay = float(
317
+ os.environ.get("GROQ_CHUNK_DELAY_SECONDS", "3")
318
+ )
319
+ logger.info(
320
+ "πŸš€ NoteGenerator v5.0 initialized β€” primary: %s, map: %s, delay: %.1fs",
321
+ self.model_primary, self.model_map, self.chunk_delay,
322
+ )
323
+
324
+ # ── Low-level API call ──────────────────────────────────────────────
325
 
326
+ def _chat(
327
+ self,
328
+ system: str,
329
+ user: str,
330
+ model: Optional[str] = None,
331
+ max_tokens: int = 4096,
332
+ ) -> Optional[str]:
333
+ """Send a chat completion request to Groq."""
334
+ model = model or self.model_primary
335
  try:
336
  response = self.client.chat.completions.create(
337
+ model=model,
338
  max_tokens=max_tokens,
339
  temperature=0.3,
340
  response_format={"type": "json_object"},
 
345
  )
346
  return response.choices[0].message.content
347
  except Exception as e:
348
+ logger.error("❌ Groq API call failed (model=%s): %s", model, e)
349
  return None
350
 
351
+ # ── Error fallback ──────────────────────────────────────────────────
352
+
353
  def _get_error_json(self, error_msg: str) -> Dict:
354
  return {
355
  "title": "Error in Generation",
 
360
  "topics": [],
361
  }
362
 
363
+ # ── Single-pass summarization (short transcripts) ───────────────────
364
+
365
def _single_pass(self, transcript_text: str, video_title: str) -> Dict:
    """
    Summarize the whole transcript with one request to the primary model.

    Returns the parsed and validated summary, or the error payload from
    ``_get_error_json`` when the API call produced no response.
    """
    logger.info("πŸ“ Single-pass summarization via %s", self.model_primary)

    response_text = self._chat(
        _SUMMARY_SYSTEM,
        _SUMMARY_USER.format(
            transcript=transcript_text,
            video_title=video_title,
        ),
        max_tokens=4096,
    )

    if response_text is not None:
        return self._parse_and_validate(response_text)
    return self._get_error_json("Groq API call failed (single-pass).")
379
+
380
+ # ── Map-Reduce summarization (long transcripts) ─────────────────────
381
+
382
def _map_reduce(self, transcript_text: str, video_title: str) -> Dict:
    """
    Summarize a long transcript in two stages (MAP, then REDUCE).

    MAP: the transcript is split with ``_split_into_chunks`` and each chunk
    is sent to ``self.model_map``. Chunks whose call fails, or whose
    response is not a JSON object, are logged and skipped. REDUCE: the
    surviving chunk results are rendered as plain text and sent in one
    request to ``self.model_primary``; that response is parsed and
    validated by ``_parse_and_validate``.

    ``self.chunk_delay`` seconds are slept between consecutive API calls
    when it is positive.

    Returns:
        The validated summary dict, or the payload from
        ``_get_error_json`` when no chunk could be summarized or the
        REDUCE call returned nothing.
    """
    chunks = _split_into_chunks(transcript_text)
    total = len(chunks)
    logger.info(
        "πŸ—ΊοΈ Map-Reduce activated: %d chunks (delay=%.1fs between calls)",
        total, self.chunk_delay,
    )

    # ---- MAP phase ----
    intermediate_results: List[Dict] = []

    for i, chunk in enumerate(chunks, start=1):
        logger.info(
            " πŸ“¦ MAP chunk %d/%d (~%d tokens)...", i, total, _estimate_tokens(chunk),
        )

        raw = self._chat(
            _MAP_SYSTEM,
            _MAP_USER.format(
                video_title=video_title,
                chunk_index=i,
                total_chunks=total,
                chunk_text=chunk,
            ),
            model=self.model_map,
            max_tokens=2048,
        )

        if raw:
            try:
                parsed = json.loads(raw)
            except json.JSONDecodeError as e:
                logger.warning(
                    " ⚠️ MAP chunk %d/%d returned invalid JSON: %s", i, total, e,
                )
            else:
                # Valid JSON is not necessarily an object; anything else
                # would break the ``.get`` lookups in the REDUCE phase.
                if isinstance(parsed, dict):
                    intermediate_results.append(parsed)
                    logger.info(" βœ… MAP chunk %d/%d done.", i, total)
                else:
                    logger.warning(
                        " ⚠️ MAP chunk %d/%d returned invalid JSON: %s", i, total,
                        f"expected an object, got {type(parsed).__name__}",
                    )
        else:
            logger.warning(" ⚠️ MAP chunk %d/%d returned no response.", i, total)

        # No pause after the last chunk; the REDUCE step sleeps on its own.
        if i < total and self.chunk_delay > 0:
            logger.info(" ⏳ Sleeping %.1fs (TPM cooldown)...", self.chunk_delay)
            time.sleep(self.chunk_delay)

    if not intermediate_results:
        return self._get_error_json(
            "Map-Reduce failed: no chunks were successfully summarized."
        )

    # ---- REDUCE phase ----
    logger.info("πŸ”— REDUCE phase: merging %d intermediate summaries...", len(intermediate_results))

    # Render every chunk result as a small text section. The fields come
    # from model output, so their types are checked instead of assumed.
    merged_parts: List[str] = []
    for idx, result in enumerate(intermediate_results, start=1):
        key_points = result.get("key_points", [])
        if not isinstance(key_points, list):
            key_points = []

        topics = result.get("topics", [])
        if isinstance(topics, str):
            topics = [topics]
        elif not isinstance(topics, list):
            topics = []

        lines = [
            f"--- Chunk {idx} ---",
            f"Summary: {result.get('chunk_summary', '')}",
        ]
        lines.extend(
            f"- {kp.get('title', '')}: {kp.get('detail', '')} "
            f"(Insight: {kp.get('insight', '')})"
            for kp in key_points
            if isinstance(kp, dict)
        )
        lines.append(f"Topics: {', '.join(str(topic) for topic in topics)}")
        merged_parts.append("\n".join(lines) + "\n")

    merged_text = "\n".join(merged_parts)

    # The size is only logged for visibility; no limit is enforced here.
    reduce_tokens = _estimate_tokens(merged_text)
    logger.info("πŸ”— REDUCE input: ~%d tokens", reduce_tokens)

    user_prompt = _REDUCE_USER.format(
        video_title=video_title,
        total_chunks=len(intermediate_results),
        merged_summaries=merged_text,
    )

    # Cool down after the last MAP call before hitting the primary model.
    if self.chunk_delay > 0:
        logger.info(" ⏳ Sleeping %.1fs before REDUCE call...", self.chunk_delay)
        time.sleep(self.chunk_delay)

    raw = self._chat(
        _REDUCE_SYSTEM, user_prompt,
        model=self.model_primary,
        max_tokens=4096,
    )

    if raw is None:
        return self._get_error_json("Groq API call failed (REDUCE phase).")

    return self._parse_and_validate(raw)
489
+
490
+ # ── JSON parsing + schema validation ────────────────────────────────
491
+
492
def _parse_and_validate(self, raw_json: str) -> Dict:
    """
    Parse a raw JSON string and validate it against ``SummarySchema``.

    Args:
        raw_json: JSON text returned by the model.

    Returns:
        ``SummarySchema(...).model_dump()`` on success. If the text is not
        valid JSON, is valid JSON but not an object, or fails schema
        validation, the failure is logged and the payload from
        ``_get_error_json`` is returned instead of raising.
    """
    try:
        data = json.loads(raw_json)
        # ``SummarySchema(**data)`` needs a mapping; a top-level list,
        # string or number would otherwise escape as an uncaught TypeError.
        if not isinstance(data, dict):
            raise TypeError(
                f"expected a JSON object, got {type(data).__name__}"
            )
        validated = SummarySchema(**data)
        return validated.model_dump()
    except (json.JSONDecodeError, ValidationError, TypeError) as e:
        logger.error("❌ Schema validation failed: %s", e)
        return self._get_error_json(f"Validation Error: {str(e)}")
501
 
502
+ # ── Public API (unchanged signature) ────────────────────────────────
503
+
504
def generateSummary(self, transcript_text: str, video_title: str) -> Dict:
    """
    Generate a structured JSON summary from a transcript.

    The full single-pass prompt (system + user) is measured with
    ``_estimate_tokens``. Below ``_SINGLE_PASS_TOKEN_LIMIT`` the transcript
    is handled by ``_single_pass``; otherwise by ``_map_reduce``.

    Args:
        transcript_text: Transcript to summarize.
        video_title: Title inserted into the prompts.

    Returns:
        The dict produced by the selected pipeline. The payload from
        ``_get_error_json`` is returned instead when no API client is
        configured or when the transcript is empty/whitespace-only (no
        request is sent in either case).
    """
    if not self.client:
        return self._get_error_json("Groq API Key missing.")

    # Nothing to summarize: avoid spending an API call on an empty prompt.
    if not transcript_text or not transcript_text.strip():
        return self._get_error_json("Transcript is empty.")

    # Measure the prompt exactly as the single-pass path would send it.
    full_prompt = _SUMMARY_USER.format(
        video_title=video_title,
        transcript=transcript_text,
    )
    total_tokens = _estimate_tokens(_SUMMARY_SYSTEM + full_prompt)

    logger.info(
        "πŸ“Š Token estimate: ~%d tokens (threshold: %d)",
        total_tokens, _SINGLE_PASS_TOKEN_LIMIT,
    )

    if total_tokens < _SINGLE_PASS_TOKEN_LIMIT:
        return self._single_pass(transcript_text, video_title)

    logger.info(
        "⚑ Transcript too large for single-pass (%d β‰₯ %d). "
        "Activating Map-Reduce pipeline...",
        total_tokens, _SINGLE_PASS_TOKEN_LIMIT,
    )
    return self._map_reduce(transcript_text, video_title)
535
+
536
+ # ── Markdown formatting (unchanged) ─────────────────────────────────
537
+
538
  def format_notes_to_markdown(self, json_notes: Dict) -> str:
539
  """Convert JSON notes to clean Markdown β€” Summary β†’ Timeline β†’ Conclusion."""
540
  lang = json_notes.get("detected_language", "English")