# ================================================================
# File: inference.py
# Description:
#     Inference script for FinGPT Task II (Compliance Agents)
#     using Hugging Face model "Fin-01-8B" and a local XBRL
#     knowledge base for Retrieval-Augmented Generation (RAG).
# ================================================================
import json
import os
import re
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


# ================================================================
# 1️⃣ Load the Hugging Face Model (Fin-01-8B)
# ================================================================
def load_model(model_name_or_path: str = "Fin-01-8B"):
    """
    Load the tokenizer and causal LM from the Hugging Face Hub.

    Args:
        model_name_or_path: Hub model id or local checkpoint directory.

    Returns:
        ``(tokenizer, model)``; the model is in eval mode, placed via
        ``device_map="auto"``, fp16 on GPU / fp32 on CPU.

    Raises:
        RuntimeError: if the tokenizer or model weights fail to load.
    """
    print(f"🔹 Loading model from Hugging Face: '{model_name_or_path}'...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"❌ Failed to load tokenizer: {e}") from e

    # Ensure a pad token exists so generation with padding does not crash.
    # NOTE(review): assigning "[PAD]" as a raw string does not resize the
    # model's embedding matrix; in practice eos_token is almost always set,
    # so the "[PAD]" fallback is a last resort — confirm for exotic models.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token or "[PAD]"

    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            # fp16 halves GPU memory; CPU stays fp32 for numeric safety.
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",
            low_cpu_mem_usage=True,
        )
    except Exception as e:
        raise RuntimeError(f"❌ Failed to load model weights: {e}") from e

    model.eval()
    print(f"✅ Model '{model_name_or_path}' loaded successfully.")
    return tokenizer, model


# ================================================================
# 2️⃣ Load Local XBRL Knowledge Base
# ================================================================
def load_knowledge_base(kb_path: str = "xbrl_results_2_spec_filtered_reindexed.json"):
    """
    Load the local JSON knowledge base used for retrieval.

    Args:
        kb_path: Path to a JSON file containing a list of documents
            (each expected to carry "title"/"text"/"url" keys).

    Returns:
        The parsed list of document dicts.

    Raises:
        FileNotFoundError: if ``kb_path`` does not exist.
        ValueError: if the JSON root is not a list.
    """
    print("🔹 Loading local XBRL knowledge base...")
    if not os.path.exists(kb_path):
        raise FileNotFoundError(f"❌ Knowledge base not found at '{kb_path}'.")

    with open(kb_path, "r", encoding="utf-8") as f:
        kb = json.load(f)

    if not isinstance(kb, list):
        raise ValueError("❌ Knowledge base JSON must be a list of documents.")

    print(f"✅ Knowledge base loaded successfully with {len(kb)} entries.")
    return kb


# ================================================================
# 3️⃣ New Tool: Retrieval from Local XBRL Knowledge Base
# ================================================================
def _tokenize(text: str):
    """Lightweight tokenizer for keyword retrieval (lowercased word chars)."""
    return re.findall(r"\w+", text.lower())


def retrieve_from_xbrl_database(query: str, kb: list, top_k: int = 2, max_chars: int = 1500) -> str:
    """
    Retrieve the top-k relevant context snippets from the local XBRL KB.

    Scoring is simple keyword overlap; title matches weigh 3x text matches.

    Args:
        query: Natural-language user question.
        kb: List of document dicts ("title"/"text"/"url").
        top_k: Maximum number of documents to include.
        max_chars: Per-document snippet length cap.

    Returns:
        A formatted context string, or "" when nothing matches.
    """
    if not kb:
        return ""

    query_words = set(_tokenize(query))
    scores = []
    for doc in kb:
        title = doc.get("title", "")
        text = doc.get("text", "")
        title_words = set(_tokenize(title))
        text_words = set(_tokenize(text))
        # Title hits are a stronger relevance signal than body hits.
        score = len(query_words & title_words) * 3 + len(query_words & text_words)
        if score > 0:
            scores.append((score, doc))

    if not scores:
        return ""

    # Sort documents by score in descending order and keep the best top_k.
    scores.sort(key=lambda x: x[0], reverse=True)
    top_docs = [d for _, d in scores[:top_k]]

    # Format the top_k results as a single context string.
    parts = []
    for doc in top_docs:
        snippet = (doc.get("text") or "")[:max_chars]
        parts.append(
            f"Source: {doc.get('url', 'N/A')}\n"
            f"Title: {doc.get('title', 'Untitled')}\n\n"
            f"Snippet: {snippet}\n\n"
            "---\n\n"
        )
    return "".join(parts).strip()


# ================================================================
# 4️⃣ Model Inference with Context (RAG)
# ================================================================
def generate_response(
    model,
    tokenizer,
    prompt: str,
    context: Optional[str] = None,
    temperature: float = 0.2,
    max_new_tokens: int = 512,
) -> str:
    """
    Generate a response from the model given a prompt and optional context.

    Args:
        model: A loaded causal LM (see :func:`load_model`).
        tokenizer: Its matching tokenizer.
        prompt: The user question.
        context: Optional retrieved context to ground the answer.
        temperature: Sampling temperature (must be > 0 since sampling is on).
        max_new_tokens: Generation budget, also reserved out of the input.

    Returns:
        The decoded model answer (input prompt stripped).
    """
    if context:
        full_input = (
            "Based on the following context from the XBRL specifications, "
            "please answer the question.\n\n"
            f"[Context]\n{context}\n\n"
            f"[Question]\n{prompt}\n\n"
            "[Answer]\n"
        )
    else:
        full_input = f"Question: {prompt}\nAnswer:\n"

    # Reserve room for generation; clamp so max_length never goes <= 0
    # when a tokenizer reports a small model_max_length.
    max_input_len = max(1, tokenizer.model_max_length - max_new_tokens)
    inputs = tokenizer(
        full_input,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_len,
    ).to(model.device)

    # Bug fix: `pad_token_id or eos_token_id` silently discarded a
    # legitimate pad id of 0 (falsy); compare against None instead.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.9,
            do_sample=True,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    input_len = inputs["input_ids"].shape[1]
    new_tokens = outputs[0][input_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


# ================================================================
# 5️⃣ The RAG Inference Pipeline
# ================================================================
def xbrl_compliance_agent(query: str, model, tokenizer, kb: list):
    """
    Full pipeline:
        1. Retrieve context from the local XBRL knowledge base.
        2. Generate an answer with the loaded model.

    Args:
        query: The user question.
        model: Loaded causal LM.
        tokenizer: Its matching tokenizer.
        kb: Knowledge-base document list.

    Returns:
        The generated answer string.
    """
    print(f"\n🔹 Retrieving context for: '{query}'...")
    context = retrieve_from_xbrl_database(query, kb, top_k=2)
    if context:
        print("✅ Context retrieval complete.")
    else:
        print("⚠️ No relevant context found.")

    print("🔹 Generating response from Fin-01-8B...")
    answer = generate_response(model, tokenizer, query, context)
    return answer


# -----------------------------
# 6️⃣ Example Run
# -----------------------------
if __name__ == "__main__":
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # 1. Load the model.
    try:
        tokenizer, model = load_model("Fin-01-8B")
    except Exception as e:
        print(f"❌ 模型加载失败: {e}")
        # raise SystemExit instead of the site-module `exit()` helper,
        # which is not guaranteed to exist in all environments.
        raise SystemExit(1)

    # 2. Load the knowledge base.
    try:
        kb = load_knowledge_base("xbrl_results_2_spec_filtered_reindexed.json")
    except Exception as e:
        print(f"❌ 知识库加载失败: {e}")
        raise SystemExit(1)

    print("\n🧠 FinGPT Compliance Agent 已启动,输入 'exit' 退出。\n")

    # 3. Interactive Q&A loop.
    while True:
        query = input("请输入关于XBRL合规的问题:").strip()
        if query.lower() in ("exit", "quit"):
            print("👋 退出程序。")
            break
        if not query:
            continue
        try:
            result = xbrl_compliance_agent(query, model, tokenizer, kb)
            print("\n=== AI 回复 ===\n")
            print(result)
            print("\n" + "=" * 40 + "\n")
        except Exception as e:
            # Top-level boundary: report and keep the REPL alive.
            print(f"❌ 推理出错: {e}\n")