|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import re |
|
|
import json |
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model(model_name_or_path="Fin-01-8B"):
    """Load the Fin-01-8B tokenizer and causal LM from the Hugging Face Hub.

    Chooses float16 on CUDA machines (float32 otherwise), lets HF map the
    model across devices automatically, and guarantees the tokenizer has a
    pad token so generation does not warn or fail.

    Args:
        model_name_or_path: Hub repo id or local checkpoint directory.

    Returns:
        A ``(tokenizer, model)`` tuple, with the model already in eval mode.

    Raises:
        RuntimeError: if either the tokenizer or the model weights fail to load.
    """
    print(f"🔹 Loading model from Hugging Face: '{model_name_or_path}'...")

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    except Exception as err:
        raise RuntimeError(f"❌ Failed to load tokenizer: {err}")

    # Generation needs a pad token; fall back to EOS, or "[PAD]" as a last
    # resort when the tokenizer defines neither.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token or "[PAD]"

    use_cuda = torch.cuda.is_available()
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device_map="auto",
            low_cpu_mem_usage=True,
        )
    except Exception as err:
        raise RuntimeError(f"❌ Failed to load model weights: {err}")

    model.eval()
    print(f"✅ Model '{model_name_or_path}' loaded successfully.")
    return tokenizer, model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_knowledge_base(kb_path="xbrl_results_2_spec_filtered_reindexed.json"):
    """Read the local XBRL knowledge base used for retrieval-augmented generation.

    Args:
        kb_path: Path to a JSON file whose top-level value is a list of
            document dicts.

    Returns:
        The parsed list of documents.

    Raises:
        FileNotFoundError: if ``kb_path`` does not exist.
        ValueError: if the JSON top-level value is not a list.
    """
    print("🔹 Loading local XBRL knowledge base...")

    if not os.path.exists(kb_path):
        raise FileNotFoundError(f"❌ Knowledge base not found at '{kb_path}'.")

    with open(kb_path, "r", encoding="utf-8") as handle:
        documents = json.load(handle)

    if not isinstance(documents, list):
        raise ValueError("❌ Knowledge base JSON must be a list of documents.")

    print(f"✅ Knowledge base loaded successfully with {len(documents)} entries.")
    return documents
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _tokenize(text: str):
    """Lowercase *text* and split it into word tokens for keyword retrieval."""
    return [match.group(0) for match in re.finditer(r"\w+", text.lower())]
|
|
|
|
|
|
|
|
def retrieve_from_xbrl_database(query: str, kb: list, top_k: int = 2, max_chars: int = 1500) -> str:
    """Return a formatted context string built from the best-matching KB docs.

    Ranking is plain keyword overlap: a query word shared with a document
    title counts triple, one shared with the body counts once. The ``top_k``
    highest-scoring documents (score > 0) are rendered as labelled snippets.

    Args:
        query: Free-text question to match against the knowledge base.
        kb: List of document dicts with "title"/"text" (and optionally "url").
        top_k: Maximum number of documents to include in the context.
        max_chars: Per-document cap on snippet length.

    Returns:
        The assembled context block, or "" when the KB is empty or no
        document shares any word with the query.
    """
    if not kb:
        return ""

    query_words = set(re.findall(r"\w+", query.lower()))

    # Score every document by weighted keyword overlap with the query.
    ranked = []
    for entry in kb:
        title_words = set(re.findall(r"\w+", entry.get("title", "").lower()))
        body_words = set(re.findall(r"\w+", entry.get("text", "").lower()))
        overlap = 3 * len(query_words & title_words) + len(query_words & body_words)
        if overlap > 0:
            ranked.append((overlap, entry))

    if not ranked:
        return ""

    # Stable sort: ties keep their original KB order.
    ranked.sort(key=lambda pair: pair[0], reverse=True)

    pieces = []
    for _, entry in ranked[:top_k]:
        snippet = (entry.get("text") or "")[:max_chars]
        pieces.append(
            f"Source: {entry.get('url', 'N/A')}\n"
            f"Title: {entry.get('title', 'Untitled')}\n\n"
            f"Snippet: {snippet}\n\n"
            "---\n\n"
        )

    return "".join(pieces).strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_response(
    model,
    tokenizer,
    prompt: str,
    context: str = None,
    temperature: float = 0.2,
    max_new_tokens: int = 512,
) -> str:
    """Generate an answer from the Fin-01-8B model for a prompt and optional context.

    Args:
        model: Causal LM exposing ``generate`` and ``device`` (HF interface).
        tokenizer: Matching tokenizer; must expose pad/eos token ids.
        prompt: The user question.
        context: Optional retrieved context; when truthy, a RAG-style prompt
            template is used instead of the bare Q/A template.
        temperature: Sampling temperature (sampling is always on, so this
            should be > 0).
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        The decoded continuation only (the echoed prompt is stripped off).
    """
    if context:
        full_input = (
            "Based on the following context from the XBRL specifications, "
            "please answer the question.\n\n"
            f"[Context]\n{context}\n\n"
            f"[Question]\n{prompt}\n\n"
            "[Answer]\n"
        )
    else:
        full_input = f"Question: {prompt}\nAnswer:\n"

    # Truncate the prompt so the generated continuation still fits the window.
    # NOTE(review): some tokenizers report model_max_length as a huge sentinel
    # (~1e30); the subtraction then effectively disables truncation — confirm
    # the Fin-01-8B tokenizer sets a real context length.
    inputs = tokenizer(
        full_input,
        return_tensors="pt",
        truncation=True,
        max_length=tokenizer.model_max_length - max_new_tokens
    ).to(model.device)

    # BUGFIX: the previous `tokenizer.pad_token_id or tokenizer.eos_token_id`
    # treated a legitimate pad id of 0 as missing (falsy) and silently
    # substituted the EOS id. Compare against None explicitly instead.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.9,
            do_sample=True,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens, skipping the echoed input.
    input_len = inputs["input_ids"].shape[1]
    new_tokens = outputs[0][input_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def xbrl_compliance_agent(query: str, model, tokenizer, kb: list):
    """End-to-end QA pipeline: retrieve XBRL context, then generate an answer.

    Args:
        query: User question about XBRL compliance.
        model: Loaded causal LM.
        tokenizer: Tokenizer matching ``model``.
        kb: Knowledge-base document list for retrieval.

    Returns:
        The model's answer string.
    """
    print(f"\n🔹 Retrieving context for: '{query}'...")
    context = retrieve_from_xbrl_database(query, kb, top_k=2)
    print("✅ Context retrieval complete." if context else "⚠️ No relevant context found.")

    print("🔹 Generating response from Fin-01-8B...")
    return generate_response(model, tokenizer, query, context)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Silence tokenizer fork warnings in the interactive loop.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Load the model first; abort on failure.
    try:
        tokenizer, model = load_model("Fin-01-8B")
    except Exception as err:
        print(f"❌ 模型加载失败: {err}")
        exit(1)

    # Then load the local knowledge base; abort on failure.
    try:
        kb = load_knowledge_base("xbrl_results_2_spec_filtered_reindexed.json")
    except Exception as err:
        print(f"❌ 知识库加载失败: {err}")
        exit(1)

    print("\n🧠 FinGPT Compliance Agent 已启动,输入 'exit' 退出。\n")

    # Interactive REPL: empty input is ignored, 'exit'/'quit' terminates.
    while True:
        user_query = input("请输入关于XBRL合规的问题:").strip()
        if user_query.lower() in {"exit", "quit"}:
            print("👋 退出程序。")
            break
        if not user_query:
            continue

        try:
            answer = xbrl_compliance_agent(user_query, model, tokenizer, kb)
            print("\n=== AI 回复 ===\n")
            print(answer)
            print("\n" + "=" * 40 + "\n")
        except Exception as err:
            print(f"❌ 推理出错: {err}\n")