ZedLow committed on
Commit
82f0b9f
·
verified ·
1 Parent(s): 59fdc20

Create data.py

Browse files
Files changed (1) hide show
  1. rag/data.py +60 -0
rag/data.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import List, Optional
5
+
6
+ from rag.logging_utils import get_logger
7
+
8
# Module-level logger named after this module, obtained from the project's
# rag.logging_utils.get_logger helper.
logger = get_logger(__name__)
9
+
10
@dataclass
class Doc:
    """A single dataset record: a named document with its image path and text."""

    # Name of the document the record belongs to.
    doc_name: str
    # Path of the image associated with the record.
    image_path: str
    # Text content of the record.
    text: str
    # Company label; expected values: "Apple" | "Microsoft" | "Unknown"
    company: str
16
+
17
def infer_company(doc_name: str) -> str:
    """Guess the company label from a document name.

    The name is lower-cased and searched for known substrings. Microsoft
    markers ("microsoft", "msft") are tested before Apple markers ("apple",
    "aapl"), so a name containing both resolves to "Microsoft".

    Args:
        doc_name: Document name; a falsy value (``None`` or ``""``) is
            treated as an empty name.

    Returns:
        "Microsoft", "Apple", or "Unknown" when no marker is found.
    """
    lowered = doc_name.lower() if doc_name else ""

    # Order matters: the first company with a matching marker wins.
    markers_by_company = (
        ("Microsoft", ("microsoft", "msft")),
        ("Apple", ("apple", "aapl")),
    )
    for label, markers in markers_by_company:
        if any(marker in lowered for marker in markers):
            return label
    return "Unknown"
24
+
25
def load_dataset(dataset_path: str) -> List[Doc]:
    """Load documents from a JSON dataset file.

    The file is expected to hold a JSON array of objects. Each object may
    provide ``doc_name``, ``image_path``, ``text`` and ``company``. Records
    whose ``text`` is missing, not a string, or blank after stripping are
    skipped, as are array entries that are not JSON objects.

    Args:
        dataset_path: Path to the JSON dataset file.

    Returns:
        The loaded documents in file order. An empty list is returned (with
        a logged warning) when the file does not exist or when its top-level
        JSON value is not an array.

    Raises:
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    path = Path(dataset_path)
    if not path.exists():
        logger.warning("Dataset file not found: %s", dataset_path)
        return []

    with path.open("r", encoding="utf-8") as f:
        raw = json.load(f)

    # Iterating a dict would yield its keys and a scalar is not iterable at
    # all, so anything but an array cannot be interpreted as a record list.
    if not isinstance(raw, list):
        logger.warning(
            "Dataset %s does not contain a JSON array (got %s)",
            dataset_path,
            type(raw).__name__,
        )
        return []

    docs: List[Doc] = []
    for index, item in enumerate(raw):
        if not isinstance(item, dict):
            logger.warning("Skipping record %d: not a JSON object", index)
            continue

        raw_text = item.get("text")
        text = raw_text.strip() if isinstance(raw_text, str) else ""
        if not text:
            continue

        # dict.get only applies its default when the key is absent, so an
        # explicit JSON null has to be mapped to the default separately.
        # str() keeps the Doc fields (and infer_company's input) strings
        # even when the JSON value is a number.
        doc_name = item.get("doc_name")
        doc_name = "Unknown Document" if doc_name is None else str(doc_name)
        image_path = item.get("image_path")
        image_path = "" if image_path is None else str(image_path)

        # An explicit company wins; otherwise fall back to a name-based guess.
        company = item.get("company") or infer_company(doc_name)

        docs.append(
            Doc(
                doc_name=doc_name,
                image_path=image_path,
                text=text,
                company=company,
            )
        )

    logger.info("Loaded %d docs", len(docs))
    return docs
55
+
56
def filter_docs_by_companies(docs: List[Doc], companies: Optional[List[str]]) -> List[Doc]:
    """Keep only the documents whose company is one of ``companies``.

    Args:
        docs: Documents to filter.
        companies: Company labels to keep (exact, case-sensitive match).
            ``None`` or an empty list disables filtering.

    Returns:
        ``docs`` itself (the same list object) when no filter is given;
        otherwise a new list with the matching documents in their original
        order.
    """
    if companies:
        wanted = frozenset(companies)
        kept: List[Doc] = []
        for doc in docs:
            if doc.company in wanted:
                kept.append(doc)
        return kept
    return docs