fengxb30 commited on
Commit
6870358
·
verified ·
1 Parent(s): 975e331

Update FinGPT_TaskII_Submission/scripts/inference_with_tools.py

Browse files
FinGPT_TaskII_Submission/scripts/inference_with_tools.py CHANGED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ================================================================
2
+ # File: inference_with_tools.py
3
+ # Author: fengxb30
4
+ # Description: Inference script for FinGPT Task II (Compliance Agents)
5
+ # ================================================================
6
+
7
+ import torch
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+ import requests
10
+ import yfinance as yf
11
+ from datetime import datetime
12
+ import json
13
+ import os
14
+
15
+ # -----------------------------
16
+ # 1️⃣ Load the model
17
+ # -----------------------------
18
+ def load_model(model_name_or_path="fengxb30/FinGPT_TaskII_Compliance"):
19
+ print(f"🔹 Loading model from {model_name_or_path} ...")
20
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
21
+ model = AutoModelForCausalLM.from_pretrained(
22
+ model_name_or_path,
23
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
24
+ device_map="auto"
25
+ )
26
+ model.eval()
27
+ return tokenizer, model
28
+
29
+ # -----------------------------
30
+ # 2️⃣ Financial Data API Helper
31
+ # -----------------------------
32
+ def fetch_financial_data(ticker="AAPL"):
33
+ """Fetch real-time market data from Yahoo Finance"""
34
+ stock = yf.Ticker(ticker)
35
+ info = stock.info
36
+ return {
37
+ "symbol": ticker,
38
+ "price": info.get("currentPrice", None),
39
+ "marketCap": info.get("marketCap", None),
40
+ "fiftyTwoWeekHigh": info.get("fiftyTwoWeekHigh", None),
41
+ "fiftyTwoWeekLow": info.get("fiftyTwoWeekLow", None),
42
+ }
43
+
44
+ # -----------------------------
45
+ # 3️⃣ External Tool: RAG-style retrieval
46
+ # -----------------------------
47
+ def retrieve_context_from_web(query):
48
+ """Use simple web API (DuckDuckGo Instant Answer) to retrieve context"""
49
+ url = f"https://api.duckduckgo.com/?q={query}&format=json"
50
+ try:
51
+ res = requests.get(url, timeout=10)
52
+ data = res.json()
53
+ return data.get("AbstractText") or data.get("Heading") or "No context found."
54
+ except Exception as e:
55
+ return f"Retrieval failed: {e}"
56
+
57
+ # -----------------------------
58
+ # 4️⃣ Model inference with context
59
+ # -----------------------------
60
+ def generate_response(model, tokenizer, prompt, context=None, temperature=0.2, max_new_tokens=512):
61
+ if context:
62
+ full_input = f"Context: {context}\n\nQuestion: {prompt}\nAnswer:"
63
+ else:
64
+ full_input = f"Question: {prompt}\nAnswer:"
65
+ inputs = tokenizer(full_input, return_tensors="pt").to(model.device)
66
+ outputs = model.generate(
67
+ **inputs,
68
+ max_new_tokens=max_new_tokens,
69
+ temperature=temperature,
70
+ top_p=0.9,
71
+ do_sample=True
72
+ )
73
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
74
+
75
+ # -----------------------------
76
+ # 5️⃣ Pipeline function
77
+ # -----------------------------
78
+ def compliance_agent_pipeline(prompt, ticker=None, use_web=True):
79
+ tokenizer, model = load_model()
80
+ context = ""
81
+
82
+ # Optional: integrate external tools
83
+ if ticker:
84
+ fin_data = fetch_financial_data(ticker)
85
+ context += f"Market data for {ticker}: {json.dumps(fin_data, indent=2)}\n\n"
86
+
87
+ if use_web:
88
+ web_context = retrieve_context_from_web(prompt)
89
+ context += f"Retrieved context: {web_context}\n\n"
90
+
91
+ # Generate model output
92
+ response = generate_response(model, tokenizer, prompt, context)
93
+ return response
94
+
95
+
96
+ # -----------------------------
97
+ # 6️⃣ Example run
98
+ # -----------------------------
99
+ if __name__ == "__main__":
100
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
101
+ print("🧠 FinGPT Task II - Compliance Agent Inference")
102
+ query = input("Enter your compliance query: ")
103
+
104
+ # Example: "Summarize the compliance risk of Meta in 2025"
105
+ result = compliance_agent_pipeline(prompt=query, ticker="META", use_web=True)
106
+
107
+ print("\n=== AI Compliance Agent Response ===\n")
108
+ print(result)