fengxb30 committed · verified · Commit 90c44e5 · Parent(s): aab3065

Upload inference.py

Files changed (1)
  1. inference.py +225 -0
inference.py ADDED
@@ -0,0 +1,225 @@
# ================================================================
# File: inference.py
# Description:
#   Inference script for FinGPT Task II (Compliance Agents)
#   using Hugging Face model "Fin-01-8B" and a local XBRL knowledge base.
# ================================================================

import os
import re
import sys
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


# ================================================================
# 1️⃣ Load the Hugging Face Model (Fin-01-8B)
# ================================================================
def load_model(model_name_or_path="fengxb30/Fin-01-8B"):
    """
    Loads the tokenizer and causal LM from the Hugging Face Hub (Fin-01-8B).
    Automatically sets device, dtype, and pad_token.
    """
    print(f"🔹 Loading model from Hugging Face: '{model_name_or_path}'...")

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    except Exception as e:
        raise RuntimeError(f"❌ Failed to load tokenizer: {e}")

    # Ensure a pad_token exists (many causal LMs ship without one)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token or "[PAD]"

    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",
            low_cpu_mem_usage=True,
        )
    except Exception as e:
        raise RuntimeError(f"❌ Failed to load model weights: {e}")

    model.eval()
    print(f"✅ Model '{model_name_or_path}' loaded successfully.")
    return tokenizer, model
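
# A minimal smoke test (illustrative sketch; assumes the checkpoint is
# reachable and there is enough GPU/CPU memory to hold an 8B model):
#
#     tokenizer, model = load_model()
#     print(type(model).__name__, next(model.parameters()).dtype)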


# ================================================================
# 2️⃣ Load Local XBRL Knowledge Base
# ================================================================
def load_knowledge_base(kb_path="xbrl_results_2_spec_filtered_reindexed.json"):
    """
    Loads local JSON knowledge base for Retrieval-Augmented Generation.
    """
    print("🔹 Loading local XBRL knowledge base...")
    if not os.path.exists(kb_path):
        raise FileNotFoundError(f"❌ Knowledge base not found at '{kb_path}'.")
    with open(kb_path, "r", encoding="utf-8") as f:
        kb = json.load(f)
    if not isinstance(kb, list):
        raise ValueError("❌ Knowledge base JSON must be a list of documents.")
    print(f"✅ Knowledge base loaded successfully with {len(kb)} entries.")
    return kb
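
# The retriever below reads the "title", "text", and "url" keys of each entry,
# so the KB file is assumed to contain records shaped roughly like this
# (illustrative values, not taken from the actual dataset):
#
#     [
#       {
#         "title": "XBRL 2.1 Specification, Section 4: Items",
#         "text": "An item is an element in the substitution group ...",
#         "url": "https://example.com/xbrl/spec"
#       }
#     ]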


# ================================================================
# 3️⃣ New Tool: Retrieval from Local XBRL Knowledge Base
# ================================================================
def _tokenize(text: str):
    """Lightweight tokenizer for keyword retrieval."""
    return re.findall(r"\w+", text.lower())
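
# Example: _tokenize("XBRL 2.1 Taxonomy") -> ["xbrl", "2", "1", "taxonomy"]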


def retrieve_from_xbrl_database(query: str, kb: list, top_k: int = 2, max_chars: int = 1500) -> str:
    """
    Retrieves top-k relevant context snippets from the local XBRL KB.
    Uses a simple keyword-matching retrieval algorithm.
    """
    if not kb:
        return ""

    query_words = set(_tokenize(query))
    scores = []

    for doc in kb:
        title = doc.get("title", "")
        text = doc.get("text", "")
        title_words = set(_tokenize(title))
        text_words = set(_tokenize(text))
        # Title matches are weighted 3x over body matches
        score = len(query_words & title_words) * 3 + len(query_words & text_words)
        if score > 0:
            scores.append((score, doc))

    if not scores:
        return ""

    # Sort documents by score in descending order
    scores.sort(key=lambda x: x[0], reverse=True)
    top_docs = [d for _, d in scores[:top_k]]

    # Format the top_k results as context
    context = ""
    for doc in top_docs:
        snippet = (doc.get("text") or "")[:max_chars]
        context += (
            f"Source: {doc.get('url', 'N/A')}\n"
            f"Title: {doc.get('title', 'Untitled')}\n\n"
            f"Snippet: {snippet}\n\n"
            "---\n\n"
        )

    return context.strip()
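
# Worked scoring example: for the query "context element", a document whose
# title contains "context" and whose text contains both "context" and
# "element" scores 1 * 3 + 2 = 5 (title overlap weighted 3x, body overlap 1x).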


# ================================================================
# 4️⃣ Model Inference with Context (RAG)
# ================================================================
def generate_response(
    model,
    tokenizer,
    prompt: str,
    context: str = None,
    temperature: float = 0.2,
    max_new_tokens: int = 512,
) -> str:
    """
    Generates a response using the Fin-01-8B model given a prompt and optional context.
    """
    if context:
        full_input = (
            "Based on the following context from the XBRL specifications, "
            "please answer the question.\n\n"
            f"[Context]\n{context}\n\n"
            f"[Question]\n{prompt}\n\n"
            "[Answer]\n"
        )
    else:
        full_input = f"Question: {prompt}\nAnswer:\n"

    # Some tokenizers report a huge sentinel model_max_length; fall back to a
    # conservative window in that case (4096 is an assumption, not a model fact).
    max_len = tokenizer.model_max_length
    if max_len is None or max_len > 100_000:
        max_len = 4096

    inputs = tokenizer(
        full_input,
        return_tensors="pt",
        truncation=True,
        max_length=max_len - max_new_tokens,
    ).to(model.device)

    # `or` would misfire if pad_token_id were 0, so test for None explicitly
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.9,
            do_sample=True,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt
    input_len = inputs["input_ids"].shape[1]
    new_tokens = outputs[0][input_len:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return response
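
# Standalone usage (an illustrative sketch; assumes model and tokenizer came
# from load_model() above):
#
#     answer = generate_response(model, tokenizer,
#                                "What is an XBRL context element?")
#
# For reproducible output one could switch model.generate to do_sample=False
# (greedy decoding) and drop temperature/top_p; the sampling defaults above
# trade determinism for more natural phrasing.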


# ================================================================
# 5️⃣ The RAG Inference Pipeline
# ================================================================
def xbrl_compliance_agent(query: str, model, tokenizer, kb: list):
    """
    Full pipeline:
      1. Retrieve context from the local XBRL knowledge base.
      2. Generate an answer using the Fin-01-8B model.
    """
    print(f"\n🔹 Retrieving context for: '{query}'...")
    context = retrieve_from_xbrl_database(query, kb, top_k=2)
    if context:
        print("✅ Context retrieval complete.")
    else:
        print("⚠️ No relevant context found.")

    print("🔹 Generating response from Fin-01-8B...")
    answer = generate_response(model, tokenizer, query, context)
    return answer
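
# One-shot (non-interactive) usage sketch, assuming the KB file exists locally:
#
#     tokenizer, model = load_model()
#     kb = load_knowledge_base()
#     print(xbrl_compliance_agent("How are XBRL units defined?",
#                                 model, tokenizer, kb))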


# -----------------------------
# 6️⃣ Example Run
# -----------------------------
if __name__ == "__main__":
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # 1️⃣ Load the model
    try:
        tokenizer, model = load_model("fengxb30/Fin-01-8B")
    except Exception as e:
        print(f"❌ Model loading failed: {e}")
        sys.exit(1)

    # 2️⃣ Load the knowledge base
    try:
        kb = load_knowledge_base("xbrl_results_2_spec_filtered_reindexed.json")
    except Exception as e:
        print(f"❌ Knowledge base loading failed: {e}")
        sys.exit(1)

    print("\n🧠 FinGPT Compliance Agent started. Type 'exit' to quit.\n")

    # 3️⃣ Interactive Q&A loop
    while True:
        query = input("Enter a question about XBRL compliance: ").strip()
        if query.lower() in ["exit", "quit"]:
            print("👋 Exiting.")
            break
        if not query:
            continue

        try:
            result = xbrl_compliance_agent(query, model, tokenizer, kb)
            print("\n=== AI Response ===\n")
            print(result)
            print("\n" + "=" * 40 + "\n")
        except Exception as e:
            print(f"❌ Inference error: {e}\n")