# ================================================================
# File: inference.py
# Description:
# Inference script for FinGPT Task II (Compliance Agents)
# using Hugging Face model "Fin-01-8B" and local XBRL knowledge base.
# ================================================================
import os
import re
import sys
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# ================================================================
# 1️⃣ Load the Hugging Face Model (Fin-01-8B)
# ================================================================
def load_model(model_name_or_path="Fin-01-8B"):
"""
Loads the tokenizer and causal LM from Hugging Face Hub (Fin-01-8B).
Automatically sets device, dtype, and pad_token.
"""
print(f"🔹 Loading model from Hugging Face: '{model_name_or_path}'...")
try:
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
except Exception as e:
raise RuntimeError(f"❌ Failed to load tokenizer: {e}")
# Ensure pad_token exists
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token or "[PAD]"
try:
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto",
low_cpu_mem_usage=True
)
except Exception as e:
raise RuntimeError(f"❌ Failed to load model weights: {e}")
model.eval()
print(f"✅ Model '{model_name_or_path}' loaded successfully.")
return tokenizer, model
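
# Optional: on GPUs with limited memory, an 8B model can instead be loaded with
# 4-bit quantization. A minimal sketch, assuming `bitsandbytes` is installed
# (not part of the original pipeline):
#
#     from transformers import BitsAndBytesConfig
#
#     def load_model_4bit(model_name_or_path="Fin-01-8B"):
#         bnb_config = BitsAndBytesConfig(
#             load_in_4bit=True,
#             bnb_4bit_compute_dtype=torch.float16,
#         )
#         tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
#         model = AutoModelForCausalLM.from_pretrained(
#             model_name_or_path,
#             quantization_config=bnb_config,
#             device_map="auto",
#         )
#         return tokenizer, model
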
# ================================================================
# 2️⃣ Load Local XBRL Knowledge Base
# ================================================================
def load_knowledge_base(kb_path="xbrl_results_2_spec_filtered_reindexed.json"):
"""
Loads local JSON knowledge base for Retrieval-Augmented Generation.
"""
print("🔹 Loading local XBRL knowledge base...")
if not os.path.exists(kb_path):
raise FileNotFoundError(f"❌ Knowledge base not found at '{kb_path}'.")
with open(kb_path, "r", encoding="utf-8") as f:
kb = json.load(f)
if not isinstance(kb, list):
raise ValueError("❌ Knowledge base JSON must be a list of documents.")
print(f"✅ Knowledge base loaded successfully with {len(kb)} entries.")
return kb
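
# Note: the retriever below assumes each knowledge-base entry is a dict with
# "title", "text", and (optionally) "url" keys. An illustrative entry (the
# field values here are made up; only the key names come from the code):
#
#     {
#         "title": "XBRL 2.1 Specification - Contexts",
#         "text": "A context element documents the entity, period, and ...",
#         "url": "https://www.xbrl.org/..."
#     }
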
# ================================================================
# 3️⃣ New Tool: Retrieval from Local XBRL Knowledge Base
# ================================================================
def _tokenize(text: str):
"""Lightweight tokenizer for keyword retrieval."""
return re.findall(r"\w+", text.lower())
def retrieve_from_xbrl_database(query: str, kb: list, top_k: int = 2, max_chars: int = 1500) -> str:
"""
Retrieves top-k relevant context snippets from the local XBRL KB.
Uses a simple keyword-matching retrieval algorithm.
"""
if not kb:
return ""
query_words = set(_tokenize(query))
scores = []
for doc in kb:
title = doc.get("title", "")
text = doc.get("text", "")
title_words = set(_tokenize(title))
text_words = set(_tokenize(text))
score = len(query_words & title_words) * 3 + len(query_words & text_words)
if score > 0:
scores.append((score, doc))
if not scores:
return ""
# Sort documents by score in descending order
scores.sort(key=lambda x: x[0], reverse=True)
top_docs = [d for _, d in scores[:top_k]]
# Format the top_k results as context
context = ""
for doc in top_docs:
snippet = (doc.get("text") or "")[:max_chars]
context += (
f"Source: {doc.get('url', 'N/A')}\n"
f"Title: {doc.get('title', 'Untitled')}\n\n"
f"Snippet: {snippet}\n\n"
"---\n\n"
)
return context.strip()
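
# Quick sanity check of the retriever (toy data; illustrative only):
#
#     _demo_kb = [
#         {"title": "XBRL Taxonomy Basics", "text": "A taxonomy defines concepts.", "url": "n/a"},
#         {"title": "Unrelated", "text": "Nothing to see here.", "url": "n/a"},
#     ]
#     print(retrieve_from_xbrl_database("xbrl taxonomy", _demo_kb, top_k=1))
#     # -> returns the "XBRL Taxonomy Basics" snippet;
#     #    score = 2 title hits * 3 + 1 text hit = 7
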
# ================================================================
# 4️⃣ Model Inference with Context (RAG)
# ================================================================
def generate_response(
    model,
    tokenizer,
    prompt: str,
    context: str = None,
    temperature: float = 0.2,
    max_new_tokens: int = 512,
) -> str:
    """
    Generates a response from the Fin-01-8B model given a prompt and optional context.
    """
    if context:
        full_input = (
            "Based on the following context from the XBRL specifications, "
            "please answer the question.\n\n"
            f"[Context]\n{context}\n\n"
            f"[Question]\n{prompt}\n\n"
            "[Answer]\n"
        )
    else:
        full_input = f"Question: {prompt}\nAnswer:\n"

    # Some tokenizers report a huge sentinel value for model_max_length; cap it
    # so truncation still leaves room for the generated tokens. (The 8192 cap is
    # a conservative assumption, not read from the model config.)
    max_input_len = min(tokenizer.model_max_length, 8192) - max_new_tokens
    inputs = tokenizer(
        full_input,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_len,
    ).to(model.device)

    # Don't use `or` here: a pad_token_id of 0 is falsy but valid.
    pad_token_id = (
        tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    )
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.9,
            do_sample=True,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens (skip the prompt)
    input_len = inputs["input_ids"].shape[1]
    new_tokens = outputs[0][input_len:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return response
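
# Note: for compliance Q&A a deterministic answer is often preferable. Passing
# do_sample=False to model.generate() (and omitting temperature/top_p) switches
# to greedy decoding; this is a standard transformers option rather than what
# the script above uses.
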
# ================================================================
# 5️⃣ The RAG Inference Pipeline
# ================================================================
def xbrl_compliance_agent(query: str, model, tokenizer, kb: list):
"""
Full pipeline:
1. Retrieve context from local XBRL knowledge base.
2. Generate answer using Fin-01-8B model.
"""
print(f"\n🔹 Retrieving context for: '{query}'...")
context = retrieve_from_xbrl_database(query, kb, top_k=2)
if context:
print("✅ Context retrieval complete.")
else:
print("⚠️ No relevant context found.")
print("🔹 Generating response from Fin-01-8B...")
answer = generate_response(model, tokenizer, query, context)
return answer
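
# One-shot usage sketch (non-interactive; assumes the model and KB files load
# successfully on this machine):
#
#     tokenizer, model = load_model("Fin-01-8B")
#     kb = load_knowledge_base("xbrl_results_2_spec_filtered_reindexed.json")
#     print(xbrl_compliance_agent("What is an XBRL context element?", model, tokenizer, kb))
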
# ================================================================
# 6️⃣ Example Run
# ================================================================
if __name__ == "__main__":
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # 1️⃣ Load the model
    try:
        tokenizer, model = load_model("Fin-01-8B")
    except Exception as e:
        print(f"❌ Model loading failed: {e}")
        sys.exit(1)

    # 2️⃣ Load the knowledge base
    try:
        kb = load_knowledge_base("xbrl_results_2_spec_filtered_reindexed.json")
    except Exception as e:
        print(f"❌ Knowledge base loading failed: {e}")
        sys.exit(1)

    print("\n🧠 FinGPT Compliance Agent is running. Type 'exit' to quit.\n")

    # 3️⃣ Interactive Q&A loop
    while True:
        query = input("Enter a question about XBRL compliance: ").strip()
        if query.lower() in ["exit", "quit"]:
            print("👋 Exiting.")
            break
        if not query:
            continue
        try:
            result = xbrl_compliance_agent(query, model, tokenizer, kb)
            print("\n=== AI Response ===\n")
            print(result)
            print("\n" + "=" * 40 + "\n")
        except Exception as e:
            print(f"❌ Inference error: {e}\n")