Ara Yeroyan commited on
Commit
fab49c5
·
1 Parent(s): caeff10
Files changed (1) hide show
  1. utils.py +163 -0
utils.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import dataclasses
3
+ from uuid import UUID
4
+ from typing import Any
5
+ from datetime import datetime, date
6
+
7
+
8
+ import configparser
9
+ from torch import cuda
10
+ from qdrant_client.http import models as rest
11
+ from langchain_community.embeddings import HuggingFaceEmbeddings
12
+ from langchain_community.cross_encoders import HuggingFaceCrossEncoder
13
+
14
+
15
+ def get_config(fp):
16
+ config = configparser.ConfigParser()
17
+ config.read_file(open(fp))
18
+ return config
19
+
20
+
21
+ def get_embeddings_model(config):
22
+ device = "cuda" if cuda.is_available() else "cpu"
23
+
24
+ # Define embedding model
25
+ model_name = config.get("retriever", "MODEL")
26
+ model_kwargs = {"device": device}
27
+ normalize_embeddings = bool(int(config.get("retriever", "NORMALIZE")))
28
+ encode_kwargs = {
29
+ "normalize_embeddings": normalize_embeddings,
30
+ "batch_size": 100,
31
+ }
32
+
33
+ embeddings = HuggingFaceEmbeddings(
34
+ show_progress=True,
35
+ model_name=model_name,
36
+ model_kwargs=model_kwargs,
37
+ encode_kwargs=encode_kwargs,
38
+ )
39
+
40
+ return embeddings
41
+
42
+ # Create a search filter for Qdrant
43
+ def create_filter(
44
+ reports: list = [], sources: str = None, subtype: str = None, year: str = None
45
+ ):
46
+ if len(reports) == 0:
47
+ print(f"defining filter for sources:{sources}, subtype:{subtype}")
48
+ filter = rest.Filter(
49
+ must=[
50
+ rest.FieldCondition(
51
+ key="metadata.source", match=rest.MatchValue(value=sources)
52
+ ),
53
+ rest.FieldCondition(
54
+ key="metadata.filename", match=rest.MatchAny(any=subtype)
55
+ ),
56
+ # rest.FieldCondition(
57
+ # key="metadata.year",
58
+ # match=rest.MatchAny(any=year)
59
+ ]
60
+ )
61
+ else:
62
+ print(f"defining filter for allreports:{reports}")
63
+ filter = rest.Filter(
64
+ must=[
65
+ rest.FieldCondition(
66
+ key="metadata.filename", match=rest.MatchAny(any=reports)
67
+ )
68
+ ]
69
+ )
70
+
71
+ return filter
72
+
73
+
74
+ def load_json(fp):
75
+ with open(fp, "r") as f:
76
+ docs = json.load(f)
77
+ return docs
78
+
79
+ def get_timestamp():
80
+ now = datetime.datetime.now()
81
+ timestamp = now.strftime("%Y%m%d%H%M%S")
82
+ return timestamp
83
+
84
+
85
+
86
+ # A custom class to help with recursive serialization.
87
+ # This approach avoids modifying the original object.
88
+ class _RecursiveSerializer(json.JSONEncoder):
89
+ """A custom JSONEncoder that handles complex types by converting them to dicts or strings."""
90
+ def default(self, obj):
91
+ # Prefer the pydantic method if it exists for the most robust serialization.
92
+ if hasattr(obj, 'model_dump'):
93
+ return obj.model_dump()
94
+
95
+ # Handle dataclasses
96
+ if dataclasses.is_dataclass(obj):
97
+ return dataclasses.asdict(obj)
98
+
99
+ # Handle other non-serializable but common types.
100
+ if isinstance(obj, (datetime, date, UUID)):
101
+ return str(obj)
102
+
103
+ # Fallback for general objects with a __dict__
104
+ if hasattr(obj, '__dict__'):
105
+ return obj.__dict__
106
+
107
+ # Default fallback to JSONEncoder's behavior
108
+ return super().default(obj)
109
+
110
+ def to_json_string(obj: Any, **kwargs) -> str:
111
+ """
112
+ Serializes a Python object into a JSON-formatted string.
113
+
114
+ This function is a comprehensive utility that can handle:
115
+ - Standard Python types (lists, dicts, strings, numbers, bools, None).
116
+ - Pydantic models (using `model_dump()`).
117
+ - Dataclasses (using `dataclasses.asdict()`).
118
+ - Standard library types not natively JSON-serializable (e.g., datetime, UUID).
119
+ - Custom classes with a `__dict__`.
120
+
121
+ Args:
122
+ obj (Any): The Python object to serialize.
123
+ **kwargs: Additional keyword arguments to pass to `json.dumps`.
124
+
125
+ Returns:
126
+ str: A JSON-formatted string.
127
+
128
+ Example:
129
+ >>> from datetime import datetime
130
+ >>> from pydantic import BaseModel
131
+ >>> from dataclasses import dataclass
132
+
133
+ >>> class Address(BaseModel):
134
+ ... street: str
135
+ ... city: str
136
+
137
+ >>> @dataclass
138
+ ... class Product:
139
+ ... id: int
140
+ ... name: str
141
+
142
+ >>> class Order(BaseModel):
143
+ ... user_address: Address
144
+ ... item: Product
145
+
146
+ >>> order_obj = Order(
147
+ ... user_address=Address(street="123 Main St", city="Example City"),
148
+ ... item=Product(id=1, name="Laptop")
149
+ ... )
150
+
151
+ >>> print(to_json_string(order_obj, indent=2))
152
+ {
153
+ "user_address": {
154
+ "street": "123 Main St",
155
+ "city": "Example City"
156
+ },
157
+ "item": {
158
+ "id": 1,
159
+ "name": "Laptop"
160
+ }
161
+ }
162
+ """
163
+ return json.dumps(obj, cls=_RecursiveSerializer, **kwargs)