SarahXia0405 commited on
Commit
e60e9dd
·
verified ·
1 Parent(s): 0a5c6d4

Update api/rag_engine.py

Browse files
Files changed (1) hide show
  1. api/rag_engine.py +203 -1
api/rag_engine.py CHANGED
@@ -188,6 +188,38 @@ def _parse_pptx_to_text(path: str) -> List[Tuple[str, str]]:
188
  out.append((f"slide{idx}", _clean_text("\n".join(lines))))
189
  return out
190
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
  # ----------------------------
193
  # Public API
@@ -211,9 +243,11 @@ def build_rag_chunks_from_file(path: str, doc_type: str) -> List[Dict]:
211
  sections = _parse_docx_to_text(path)
212
  elif ext == ".pptx":
213
  sections = _parse_pptx_to_text(path)
214
- elif ext in [".txt", ".md"]:
215
  with open(path, "r", encoding="utf-8", errors="ignore") as f:
216
  sections = [("text", _clean_text(f.read()))]
 
 
217
  else:
218
  print(f"[rag_engine] unsupported file type: {ext}")
219
  return []
@@ -361,3 +395,171 @@ def retrieve_relevant_chunks(
361
 
362
  context = "\n\n---\n\n".join(truncated_texts)
363
  return context, used
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  out.append((f"slide{idx}", _clean_text("\n".join(lines))))
189
  return out
190
 
191
+ import json
192
+
193
def _parse_ipynb_to_text(path: str) -> List[Tuple[str, str]]:
    """Extract the readable content of a Jupyter notebook as one text section.

    Markdown cells are kept as-is, code cells are wrapped in a fenced
    ``python`` block (the code itself matters for lab material), and any
    other cell type is kept as plain text.  Returns ``[("ipynb", text)]``,
    or ``[]`` when the file is unreadable or effectively empty.
    """
    try:
        with open(path, "r", encoding="utf-8", errors="ignore") as fh:
            notebook = json.load(fh)
    except Exception:
        # Unreadable / malformed notebook: treat it as "nothing to index".
        return []

    pieces: List[str] = []
    for cell in notebook.get("cells", []) or []:
        raw = cell.get("source", [])
        # nbformat stores source as a list of lines, but tolerate plain strings.
        text = "".join(raw) if isinstance(raw, list) else str(raw or "")
        text = text.strip()
        if not text:
            continue

        # Keep the code itself — it is important for lab notebooks.
        if cell.get("cell_type", "") == "code":
            text = "```python\n" + text + "\n```"
        pieces.append(text)

    merged = _clean_text("\n\n".join(pieces))
    return [("ipynb", merged)] if merged else []
223
 
224
  # ----------------------------
225
  # Public API
 
243
  sections = _parse_docx_to_text(path)
244
  elif ext == ".pptx":
245
  sections = _parse_pptx_to_text(path)
246
+ elif ext in [".txt", ".md", ".py"]:
247
  with open(path, "r", encoding="utf-8", errors="ignore") as f:
248
  sections = [("text", _clean_text(f.read()))]
249
+ elif ext == ".ipynb":
250
+ sections = _parse_ipynb_to_text(path)
251
  else:
252
  print(f"[rag_engine] unsupported file type: {ext}")
253
  return []
 
395
 
396
  context = "\n\n---\n\n".join(truncated_texts)
397
  return context, used
398
+
399
+ # ============================
400
+ # Course-scoped Vector Index (Simple: chunks.json + embeddings.npy)
401
+ # ============================
402
+ import json
403
+ from typing import Any
404
+ import numpy as np
405
+
406
from api.config import client, EMBEDDING_MODEL  # client is defined in api/config.py
407
+
408
+ def _course_root(course_id: str) -> str:
409
+ return os.path.join("data", "courses", course_id)
410
+
411
def _course_raw_dir(course_id: str) -> str:
    """Directory where the course's uploaded source files are stored."""
    root = _course_root(course_id)
    return os.path.join(root, "raw")
413
+
414
def _course_index_dir(course_id: str) -> str:
    """Directory where the course's vector-index artifacts are stored."""
    root = _course_root(course_id)
    return os.path.join(root, "index")
416
+
417
def _course_chunks_path(course_id: str) -> str:
    """Path of the JSON file holding the course's chunk records."""
    index_dir = _course_index_dir(course_id)
    return os.path.join(index_dir, "chunks.json")
419
+
420
def _course_emb_path(course_id: str) -> str:
    """Path of the ``.npy`` file holding the course's embedding matrix."""
    index_dir = _course_index_dir(course_id)
    return os.path.join(index_dir, "embeddings.npy")
422
+
423
def ensure_course_dirs(course_id: str) -> None:
    """Create the on-disk layout (``raw/`` and ``index/``) for a course if missing."""
    for directory in (_course_raw_dir(course_id), _course_index_dir(course_id)):
        os.makedirs(directory, exist_ok=True)
426
+
427
def _embed_texts(texts: List[str]) -> np.ndarray:
    """Embed all texts in one batched API call.

    Returns a float32 matrix of shape ``(len(texts), dim)``.
    """
    response = client.embeddings.create(model=EMBEDDING_MODEL, input=texts)
    vectors = [item.embedding for item in response.data]
    return np.asarray(vectors, dtype=np.float32)
432
+
433
def load_course_index(course_id: str) -> Tuple[List[Dict[str, Any]], Optional[np.ndarray]]:
    """Load ``(chunks, embeddings)`` for a course.

    Returns ``([], None)`` when the index files are missing, unreadable,
    or inconsistent (number of chunk records != number of embedding rows).
    """
    ensure_course_dirs(course_id)
    chunks_file = _course_chunks_path(course_id)
    emb_file = _course_emb_path(course_id)

    if not os.path.exists(chunks_file) or not os.path.exists(emb_file):
        return [], None

    try:
        with open(chunks_file, "r", encoding="utf-8") as fh:
            records = json.load(fh)
        vectors = np.load(emb_file)
        # A row-count mismatch means the index is unusable as a whole.
        if len(records) != vectors.shape[0]:
            return [], None
        return records, vectors
    except Exception:
        # Corrupt files behave exactly like a missing index.
        return [], None
450
+
451
def save_course_index(course_id: str, chunks: List[Dict[str, Any]], embs: np.ndarray) -> None:
    """Persist the chunk records (pretty-printed JSON) and embedding matrix (``.npy``)."""
    ensure_course_dirs(course_id)
    chunks_file = _course_chunks_path(course_id)
    with open(chunks_file, "w", encoding="utf-8") as fh:
        json.dump(chunks, fh, ensure_ascii=False, indent=2)
    np.save(_course_emb_path(course_id), embs)
456
+
457
def add_file_to_course_index(course_id: str, file_path: str, doc_type: str) -> Dict[str, Any]:
    """Parse -> chunk -> embed -> append -> save.

    Parses ``file_path`` into chunks, embeds them, appends them to the
    course's on-disk index, and persists the result.

    Returns a summary dict ``{"added_chunks": int, "total_chunks": int}``
    (``{"added_chunks": 0, "total_chunks": 0}`` when nothing usable was parsed).
    """
    ensure_course_dirs(course_id)

    parsed = build_rag_chunks_from_file(file_path, doc_type) or []
    # BUG FIX: filter the chunk list itself, not just the embedded texts.
    # Previously, texts with empty "text" were skipped for embedding while
    # ALL parsed chunks were appended, desynchronizing chunk records from
    # embedding rows and making load_course_index() reject the whole index.
    new_chunks = [c for c in parsed if c.get("text")]
    texts = [c["text"] for c in new_chunks]
    if not texts:
        return {"added_chunks": 0, "total_chunks": 0}

    new_embs = _embed_texts(texts)

    chunks, embs = load_course_index(course_id)
    if embs is None:
        # No (or corrupt) existing index: start fresh with an empty matrix
        # whose width matches the new embeddings.
        chunks = []
        embs = np.zeros((0, new_embs.shape[1]), dtype=np.float32)

    chunks.extend(new_chunks)
    embs = np.vstack([embs, new_embs])

    save_course_index(course_id, chunks, embs)
    return {"added_chunks": len(new_chunks), "total_chunks": len(chunks)}
480
+
481
+ def _cosine_topk(query_vec: np.ndarray, mat: np.ndarray, k: int) -> List[int]:
482
+ q = query_vec / (np.linalg.norm(query_vec) + 1e-8)
483
+ m = mat / (np.linalg.norm(mat, axis=1, keepdims=True) + 1e-8)
484
+ sims = m @ q
485
+ k = max(1, min(int(k), sims.shape[0]))
486
+ idx = np.argpartition(-sims, kth=k-1)[:k]
487
+ idx = idx[np.argsort(-sims[idx])]
488
+ return idx.tolist()
489
+
490
def retrieve_relevant_chunks_vector(
    query: str,
    course_id: str,
    k: int = RAG_TOPK_LIMIT,
    chunk_token_limit: int = RAG_CHUNK_TOKEN_LIMIT,
    max_context_tokens: int = RAG_CONTEXT_TOKEN_LIMIT,
    model_for_tokenizer: str = "",
    allowed_source_files: Optional[List[str]] = None,
    allowed_doc_types: Optional[List[str]] = None,
) -> Tuple[str, List[Dict]]:
    """
    Vector retrieval scoped to course_id, with the same scoping semantics you already use.

    Pipeline: clean the query -> load the course's (chunks, embeddings)
    index -> restrict candidates by filename / doc_type allow-lists ->
    embed the query and rank candidates by cosine similarity -> assemble a
    token-capped context string.

    Returns ``(context, used_chunks)``: the retained chunk texts joined by
    "\\n\\n---\\n\\n", and the corresponding chunk dicts in rank order.
    Returns ``("", [])`` when the query is empty, the index is missing, or
    filtering leaves no candidates.
    """
    query = _clean_text(query)
    if not query:
        return "", []

    chunks, embs = load_course_index(course_id)
    if not chunks or embs is None or embs.shape[0] == 0:
        return "", []

    # Apply the allow-list scoping BEFORE similarity ranking, so top-k is
    # computed only over chunks the caller is permitted to see.
    keep = list(range(len(chunks)))

    if allowed_source_files:
        # Compare by basename so callers may pass full paths or bare names.
        allow_files = {_basename(str(x)).strip() for x in allowed_source_files if str(x).strip()}
        if allow_files:
            keep = [i for i in keep if _basename(str(chunks[i].get("source_file", ""))).strip() in allow_files]

    if allowed_doc_types:
        allow_dt = {str(x).strip() for x in allowed_doc_types if str(x).strip()}
        if allow_dt:
            keep = [i for i in keep if str(chunks[i].get("doc_type", "")).strip() in allow_dt]

    if not keep:
        return "", []

    cand_embs = embs[keep]

    # Embed the query and rank the scoped candidates; map the local indices
    # back to positions in the full chunk list.
    qv = _embed_texts([query])[0]
    top_local = _cosine_topk(qv, cand_embs, k=min(k, RAG_TOPK_LIMIT))
    top_global = [keep[i] for i in top_local]
    used = [chunks[i] for i in top_global]

    # Token caps: each chunk is truncated to chunk_token_limit, and the
    # running total is capped at max_context_tokens (same budget logic as
    # the non-vector retrieval path).
    used_out: List[Dict] = []
    texts_out: List[str] = []
    total_tokens = 0

    for c in used:
        raw = c.get("text") or ""
        if not raw:
            continue
        t = _truncate_to_tokens(raw, max_tokens=chunk_token_limit, model=model_for_tokenizer)
        t_tokens = _count_text_tokens(t, model=model_for_tokenizer)

        if total_tokens + t_tokens > max_context_tokens:
            # Spend whatever budget is left on a shorter cut of this chunk.
            remaining = max_context_tokens - total_tokens
            if remaining <= 0:
                break
            t = _truncate_to_tokens(t, max_tokens=remaining, model=model_for_tokenizer)
            t_tokens = _count_text_tokens(t, model=model_for_tokenizer)

        # NOTE(review): t_tokens is counted before this final _clean_text,
        # so the accumulated total can slightly overstate the emitted text —
        # confirm this conservative accounting is intended.
        t = _clean_text(t)
        if not t:
            continue

        texts_out.append(t)
        used_out.append(c)
        total_tokens += t_tokens

        if total_tokens >= max_context_tokens:
            break

    context = "\n\n---\n\n".join(texts_out)
    return context, used_out