fengxb30 commited on
Commit
a8c077f
·
verified ·
1 Parent(s): de1ce6a

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +224 -224
inference.py CHANGED
@@ -1,225 +1,225 @@
1
- # ================================================================
2
- # File: inference.py
3
- # Description:
4
- # Inference script for FinGPT Task II (Compliance Agents)
5
- # using Hugging Face model "Fin-01-8B" and local XBRL knowledge base.
6
- # ================================================================
7
-
8
- import os
9
- import re
10
- import json
11
- import torch
12
- from transformers import AutoTokenizer, AutoModelForCausalLM
13
-
14
-
15
- # ================================================================
16
- # 1️⃣ Load the Hugging Face Model (Fin-01-8B)
17
- # ================================================================
18
- def load_model(model_name_or_path="fengxb30/Fin-01-8B"):
19
- """
20
- Loads the tokenizer and causal LM from Hugging Face Hub (Fin-01-8B).
21
- Automatically sets device, dtype, and pad_token.
22
- """
23
- print(f"🔹 Loading model from Hugging Face: '{model_name_or_path}'...")
24
-
25
- try:
26
- tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
27
- except Exception as e:
28
- raise RuntimeError(f"❌ Failed to load tokenizer: {e}")
29
-
30
- # Ensure pad_token exists
31
- if tokenizer.pad_token_id is None:
32
- tokenizer.pad_token = tokenizer.eos_token or "[PAD]"
33
-
34
- try:
35
- model = AutoModelForCausalLM.from_pretrained(
36
- model_name_or_path,
37
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
38
- device_map="auto",
39
- low_cpu_mem_usage=True
40
- )
41
- except Exception as e:
42
- raise RuntimeError(f"❌ Failed to load model weights: {e}")
43
-
44
- model.eval()
45
- print(f"✅ Model '{model_name_or_path}' loaded successfully.")
46
- return tokenizer, model
47
-
48
-
49
- # ================================================================
50
- # 2️⃣ Load Local XBRL Knowledge Base
51
- # ================================================================
52
- def load_knowledge_base(kb_path="xbrl_results_2_spec_filtered_reindexed.json"):
53
- """
54
- Loads local JSON knowledge base for Retrieval-Augmented Generation.
55
- """
56
- print("🔹 Loading local XBRL knowledge base...")
57
- if not os.path.exists(kb_path):
58
- raise FileNotFoundError(f"❌ Knowledge base not found at '{kb_path}'.")
59
- with open(kb_path, "r", encoding="utf-8") as f:
60
- kb = json.load(f)
61
- if not isinstance(kb, list):
62
- raise ValueError("❌ Knowledge base JSON must be a list of documents.")
63
- print(f"✅ Knowledge base loaded successfully with {len(kb)} entries.")
64
- return kb
65
-
66
-
67
- # ================================================================
68
- # 3️⃣ New Tool: Retrieval from Local XBRL Knowledge Base
69
- # ================================================================
70
- def _tokenize(text: str):
71
- """Lightweight tokenizer for keyword retrieval."""
72
- return re.findall(r"\w+", text.lower())
73
-
74
-
75
- def retrieve_from_xbrl_database(query: str, kb: list, top_k: int = 2, max_chars: int = 1500) -> str:
76
- """
77
- Retrieves top-k relevant context snippets from the local XBRL KB.
78
- Uses a simple keyword-matching retrieval algorithm.
79
- """
80
- if not kb:
81
- return ""
82
-
83
- query_words = set(_tokenize(query))
84
- scores = []
85
-
86
- for doc in kb:
87
- title = doc.get("title", "")
88
- text = doc.get("text", "")
89
- title_words = set(_tokenize(title))
90
- text_words = set(_tokenize(text))
91
- score = len(query_words & title_words) * 3 + len(query_words & text_words)
92
- if score > 0:
93
- scores.append((score, doc))
94
-
95
- if not scores:
96
- return ""
97
-
98
- # Sort documents by score in descending order
99
- scores.sort(key=lambda x: x[0], reverse=True)
100
- top_docs = [d for _, d in scores[:top_k]]
101
-
102
- # Format the top_k results as context
103
- context = ""
104
- for doc in top_docs:
105
- snippet = (doc.get("text") or "")[:max_chars]
106
- context += (
107
- f"Source: {doc.get('url', 'N/A')}\n"
108
- f"Title: {doc.get('title', 'Untitled')}\n\n"
109
- f"Snippet: {snippet}\n\n"
110
- "---\n\n"
111
- )
112
-
113
- return context.strip()
114
-
115
-
116
- # ================================================================
117
- # 4️⃣ Model Inference with Context (RAG)
118
- # ================================================================
119
- def generate_response(
120
- model,
121
- tokenizer,
122
- prompt: str,
123
- context: str = None,
124
- temperature: float = 0.2,
125
- max_new_tokens: int = 512,
126
- ) -> str:
127
- """
128
- Generates a response using Fin-01-8B model given prompt and optional context.
129
- """
130
- if context:
131
- full_input = (
132
- "Based on the following context from the XBRL specifications, "
133
- "please answer the question.\n\n"
134
- f"[Context]\n{context}\n\n"
135
- f"[Question]\n{prompt}\n\n"
136
- "[Answer]\n"
137
- )
138
- else:
139
- full_input = f"Question: {prompt}\nAnswer:\n"
140
-
141
- inputs = tokenizer(
142
- full_input,
143
- return_tensors="pt",
144
- truncation=True,
145
- max_length=tokenizer.model_max_length - max_new_tokens
146
- ).to(model.device)
147
-
148
- pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
149
-
150
- with torch.no_grad():
151
- outputs = model.generate(
152
- **inputs,
153
- max_new_tokens=max_new_tokens,
154
- temperature=temperature,
155
- top_p=0.9,
156
- do_sample=True,
157
- pad_token_id=pad_token_id,
158
- eos_token_id=tokenizer.eos_token_id
159
- )
160
-
161
- input_len = inputs["input_ids"].shape[1]
162
- new_tokens = outputs[0][input_len:]
163
- response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
164
- return response
165
-
166
-
167
- # ================================================================
168
- # 5️⃣ The RAG Inference Pipeline
169
- # ================================================================
170
- def xbrl_compliance_agent(query: str, model, tokenizer, kb: list):
171
- """
172
- Full pipeline:
173
- 1. Retrieve context from local XBRL knowledge base.
174
- 2. Generate answer using Fin-01-8B model.
175
- """
176
- print(f"\n🔹 Retrieving context for: '{query}'...")
177
- context = retrieve_from_xbrl_database(query, kb, top_k=2)
178
- if context:
179
- print("✅ Context retrieval complete.")
180
- else:
181
- print("⚠️ No relevant context found.")
182
-
183
- print("🔹 Generating response from Fin-01-8B...")
184
- answer = generate_response(model, tokenizer, query, context)
185
- return answer
186
-
187
-
188
- # -----------------------------
189
- # 6️⃣ Example Run
190
- # -----------------------------
191
- if __name__ == "__main__":
192
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
193
-
194
- # 1️⃣ 加载模型
195
- try:
196
- tokenizer, model = load_model("fengxb30/Fin-01-8B")
197
- except Exception as e:
198
- print(f"❌ 模型加载失败: {e}")
199
- exit(1)
200
-
201
- # 2️⃣ 加载知识库
202
- try:
203
- kb = load_knowledge_base("xbrl_results_2_spec_filtered_reindexed.json")
204
- except Exception as e:
205
- print(f"❌ 知识库加载失败: {e}")
206
- exit(1)
207
-
208
- print("\n🧠 FinGPT Compliance Agent 已启动,输入 'exit' 退出。\n")
209
-
210
- # 3️⃣ 交互问答
211
- while True:
212
- query = input("请输入关于XBRL合规的问题:").strip()
213
- if query.lower() in ["exit", "quit"]:
214
- print("👋 退出程序。")
215
- break
216
- if not query:
217
- continue
218
-
219
- try:
220
- result = xbrl_compliance_agent(query, model, tokenizer, kb)
221
- print("\n=== AI 回复 ===\n")
222
- print(result)
223
- print("\n" + "=" * 40 + "\n")
224
- except Exception as e:
225
  print(f"❌ 推理出错: {e}\n")
 
1
+ # ================================================================
2
+ # File: inference.py
3
+ # Description:
4
+ # Inference script for FinGPT Task II (Compliance Agents)
5
+ # using Hugging Face model "Fin-01-8B" and local XBRL knowledge base.
6
+ # ================================================================
7
+
8
+ import os
9
+ import re
10
+ import json
11
+ import torch
12
+ from transformers import AutoTokenizer, AutoModelForCausalLM
13
+
14
+
15
+ # ================================================================
16
+ # 1️⃣ Load the Hugging Face Model (Fin-01-8B)
17
+ # ================================================================
18
+ def load_model(model_name_or_path="Fin-01-8B"):
19
+ """
20
+ Loads the tokenizer and causal LM from Hugging Face Hub (Fin-01-8B).
21
+ Automatically sets device, dtype, and pad_token.
22
+ """
23
+ print(f"🔹 Loading model from Hugging Face: '{model_name_or_path}'...")
24
+
25
+ try:
26
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
27
+ except Exception as e:
28
+ raise RuntimeError(f"❌ Failed to load tokenizer: {e}")
29
+
30
+ # Ensure pad_token exists
31
+ if tokenizer.pad_token_id is None:
32
+ tokenizer.pad_token = tokenizer.eos_token or "[PAD]"
33
+
34
+ try:
35
+ model = AutoModelForCausalLM.from_pretrained(
36
+ model_name_or_path,
37
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
38
+ device_map="auto",
39
+ low_cpu_mem_usage=True
40
+ )
41
+ except Exception as e:
42
+ raise RuntimeError(f"❌ Failed to load model weights: {e}")
43
+
44
+ model.eval()
45
+ print(f"✅ Model '{model_name_or_path}' loaded successfully.")
46
+ return tokenizer, model
47
+
48
+
49
+ # ================================================================
50
+ # 2️⃣ Load Local XBRL Knowledge Base
51
+ # ================================================================
52
+ def load_knowledge_base(kb_path="xbrl_results_2_spec_filtered_reindexed.json"):
53
+ """
54
+ Loads local JSON knowledge base for Retrieval-Augmented Generation.
55
+ """
56
+ print("🔹 Loading local XBRL knowledge base...")
57
+ if not os.path.exists(kb_path):
58
+ raise FileNotFoundError(f"❌ Knowledge base not found at '{kb_path}'.")
59
+ with open(kb_path, "r", encoding="utf-8") as f:
60
+ kb = json.load(f)
61
+ if not isinstance(kb, list):
62
+ raise ValueError("❌ Knowledge base JSON must be a list of documents.")
63
+ print(f"✅ Knowledge base loaded successfully with {len(kb)} entries.")
64
+ return kb
65
+
66
+
67
+ # ================================================================
68
+ # 3️⃣ New Tool: Retrieval from Local XBRL Knowledge Base
69
+ # ================================================================
70
+ def _tokenize(text: str):
71
+ """Lightweight tokenizer for keyword retrieval."""
72
+ return re.findall(r"\w+", text.lower())
73
+
74
+
75
+ def retrieve_from_xbrl_database(query: str, kb: list, top_k: int = 2, max_chars: int = 1500) -> str:
76
+ """
77
+ Retrieves top-k relevant context snippets from the local XBRL KB.
78
+ Uses a simple keyword-matching retrieval algorithm.
79
+ """
80
+ if not kb:
81
+ return ""
82
+
83
+ query_words = set(_tokenize(query))
84
+ scores = []
85
+
86
+ for doc in kb:
87
+ title = doc.get("title", "")
88
+ text = doc.get("text", "")
89
+ title_words = set(_tokenize(title))
90
+ text_words = set(_tokenize(text))
91
+ score = len(query_words & title_words) * 3 + len(query_words & text_words)
92
+ if score > 0:
93
+ scores.append((score, doc))
94
+
95
+ if not scores:
96
+ return ""
97
+
98
+ # Sort documents by score in descending order
99
+ scores.sort(key=lambda x: x[0], reverse=True)
100
+ top_docs = [d for _, d in scores[:top_k]]
101
+
102
+ # Format the top_k results as context
103
+ context = ""
104
+ for doc in top_docs:
105
+ snippet = (doc.get("text") or "")[:max_chars]
106
+ context += (
107
+ f"Source: {doc.get('url', 'N/A')}\n"
108
+ f"Title: {doc.get('title', 'Untitled')}\n\n"
109
+ f"Snippet: {snippet}\n\n"
110
+ "---\n\n"
111
+ )
112
+
113
+ return context.strip()
114
+
115
+
116
+ # ================================================================
117
+ # 4️⃣ Model Inference with Context (RAG)
118
+ # ================================================================
119
+ def generate_response(
120
+ model,
121
+ tokenizer,
122
+ prompt: str,
123
+ context: str = None,
124
+ temperature: float = 0.2,
125
+ max_new_tokens: int = 512,
126
+ ) -> str:
127
+ """
128
+ Generates a response using Fin-01-8B model given prompt and optional context.
129
+ """
130
+ if context:
131
+ full_input = (
132
+ "Based on the following context from the XBRL specifications, "
133
+ "please answer the question.\n\n"
134
+ f"[Context]\n{context}\n\n"
135
+ f"[Question]\n{prompt}\n\n"
136
+ "[Answer]\n"
137
+ )
138
+ else:
139
+ full_input = f"Question: {prompt}\nAnswer:\n"
140
+
141
+ inputs = tokenizer(
142
+ full_input,
143
+ return_tensors="pt",
144
+ truncation=True,
145
+ max_length=tokenizer.model_max_length - max_new_tokens
146
+ ).to(model.device)
147
+
148
+ pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
149
+
150
+ with torch.no_grad():
151
+ outputs = model.generate(
152
+ **inputs,
153
+ max_new_tokens=max_new_tokens,
154
+ temperature=temperature,
155
+ top_p=0.9,
156
+ do_sample=True,
157
+ pad_token_id=pad_token_id,
158
+ eos_token_id=tokenizer.eos_token_id
159
+ )
160
+
161
+ input_len = inputs["input_ids"].shape[1]
162
+ new_tokens = outputs[0][input_len:]
163
+ response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
164
+ return response
165
+
166
+
167
+ # ================================================================
168
+ # 5️⃣ The RAG Inference Pipeline
169
+ # ================================================================
170
+ def xbrl_compliance_agent(query: str, model, tokenizer, kb: list):
171
+ """
172
+ Full pipeline:
173
+ 1. Retrieve context from local XBRL knowledge base.
174
+ 2. Generate answer using Fin-01-8B model.
175
+ """
176
+ print(f"\n🔹 Retrieving context for: '{query}'...")
177
+ context = retrieve_from_xbrl_database(query, kb, top_k=2)
178
+ if context:
179
+ print("✅ Context retrieval complete.")
180
+ else:
181
+ print("⚠️ No relevant context found.")
182
+
183
+ print("🔹 Generating response from Fin-01-8B...")
184
+ answer = generate_response(model, tokenizer, query, context)
185
+ return answer
186
+
187
+
188
+ # -----------------------------
189
+ # 6️⃣ Example Run
190
+ # -----------------------------
191
+ if __name__ == "__main__":
192
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
193
+
194
+ # 1️⃣ 加载模型
195
+ try:
196
+ tokenizer, model = load_model("Fin-01-8B")
197
+ except Exception as e:
198
+ print(f"❌ 模型加载失败: {e}")
199
+ exit(1)
200
+
201
+ # 2️⃣ 加载知识库
202
+ try:
203
+ kb = load_knowledge_base("xbrl_results_2_spec_filtered_reindexed.json")
204
+ except Exception as e:
205
+ print(f"❌ 知识库加载失败: {e}")
206
+ exit(1)
207
+
208
+ print("\n🧠 FinGPT Compliance Agent 已启动,输入 'exit' 退出。\n")
209
+
210
+ # 3️⃣ 交互问答
211
+ while True:
212
+ query = input("请输入关于XBRL合规的问题:").strip()
213
+ if query.lower() in ["exit", "quit"]:
214
+ print("👋 退出程序。")
215
+ break
216
+ if not query:
217
+ continue
218
+
219
+ try:
220
+ result = xbrl_compliance_agent(query, model, tokenizer, kb)
221
+ print("\n=== AI 回复 ===\n")
222
+ print(result)
223
+ print("\n" + "=" * 40 + "\n")
224
+ except Exception as e:
225
  print(f"❌ 推理出错: {e}\n")