ZedLow commited on
Commit
d9a3b5a
·
verified ·
1 Parent(s): 28d7a26

Update rag/data.py

Browse files
Files changed (1) hide show
  1. rag/data.py +12 -54
rag/data.py CHANGED
@@ -1,60 +1,18 @@
1
  import json
2
- from dataclasses import dataclass
3
- from pathlib import Path
4
- from typing import List, Optional
5
-
6
  from rag.logging_utils import get_logger
7
 
8
  logger = get_logger(__name__)
9
 
10
- @dataclass
11
- class Doc:
12
- doc_name: str
13
- image_path: str
14
- text: str
15
- company: str # "Apple" | "Microsoft" | "Unknown"
16
-
17
- def infer_company(doc_name: str) -> str:
18
- dn = (doc_name or "").lower()
19
- if "microsoft" in dn or "msft" in dn:
20
- return "Microsoft"
21
- if "apple" in dn or "aapl" in dn:
22
- return "Apple"
23
- return "Unknown"
24
-
25
- def load_dataset(dataset_path: str) -> List[Doc]:
26
- path = Path(dataset_path)
27
- if not path.exists():
28
- logger.warning("Dataset file not found: %s", dataset_path)
29
  return []
30
-
31
- with path.open("r", encoding="utf-8") as f:
32
- raw = json.load(f)
33
-
34
- docs: List[Doc] = []
35
- for item in raw:
36
- doc_name = item.get("doc_name", "Unknown Document")
37
- image_path = item.get("image_path", "")
38
- text = (item.get("text") or "").strip()
39
- company = item.get("company") or infer_company(doc_name)
40
-
41
- if not text:
42
- continue
43
-
44
- docs.append(
45
- Doc(
46
- doc_name=doc_name,
47
- image_path=image_path,
48
- text=text,
49
- company=company,
50
- )
51
- )
52
-
53
- logger.info("Loaded %d docs", len(docs))
54
- return docs
55
-
56
- def filter_docs_by_companies(docs: List[Doc], companies: Optional[List[str]]) -> List[Doc]:
57
- if not companies:
58
- return docs
59
- allowed = set(companies)
60
- return [d for d in docs if d.company in allowed]
 
1
  import json
2
+ from typing import List, Dict, Any
 
 
 
3
  from rag.logging_utils import get_logger
4
 
5
  logger = get_logger(__name__)
6
 
7
+ def load_dataset(path: str) -> List[Dict[str, Any]]:
8
+ try:
9
+ with open(path, "r", encoding="utf-8") as f:
10
+ data = json.load(f)
11
+ if not isinstance(data, list):
12
+ logger.warning("Dataset JSON is not a list. Found: %s", type(data))
13
+ return []
14
+ logger.info("Loaded dataset: %d docs", len(data))
15
+ return data
16
+ except Exception as e:
17
+ logger.warning("⚠️ Dataset not found/invalid (%s): %s", path, e)
 
 
 
 
 
 
 
 
18
  return []