fengxb30 commited on
Commit
62c4d20
·
verified ·
1 Parent(s): 16b0746

Delete FinGPT_TaskII_Submission/scripts/inference_with_tools.py

Browse files
FinGPT_TaskII_Submission/scripts/inference_with_tools.py DELETED
@@ -1,108 +0,0 @@
1
- # ================================================================
2
- # File: inference_with_tools.py
3
- # Author: fengxb30
4
- # Description: Inference script for FinGPT Task II (Compliance Agents)
5
- # ================================================================
6
-
7
- import torch
8
- from transformers import AutoTokenizer, AutoModelForCausalLM
9
- import requests
10
- import yfinance as yf
11
- from datetime import datetime
12
- import json
13
- import os
14
-
15
- # -----------------------------
16
- # 1️⃣ Load the model
17
- # -----------------------------
18
- def load_model(model_name_or_path="fengxb30/FinGPT_TaskII_Compliance"):
19
- print(f"🔹 Loading model from {model_name_or_path} ...")
20
- tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
21
- model = AutoModelForCausalLM.from_pretrained(
22
- model_name_or_path,
23
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
24
- device_map="auto"
25
- )
26
- model.eval()
27
- return tokenizer, model
28
-
29
- # -----------------------------
30
- # 2️⃣ Financial Data API Helper
31
- # -----------------------------
32
- def fetch_financial_data(ticker="AAPL"):
33
- """Fetch real-time market data from Yahoo Finance"""
34
- stock = yf.Ticker(ticker)
35
- info = stock.info
36
- return {
37
- "symbol": ticker,
38
- "price": info.get("currentPrice", None),
39
- "marketCap": info.get("marketCap", None),
40
- "fiftyTwoWeekHigh": info.get("fiftyTwoWeekHigh", None),
41
- "fiftyTwoWeekLow": info.get("fiftyTwoWeekLow", None),
42
- }
43
-
44
- # -----------------------------
45
- # 3️⃣ External Tool: RAG-style retrieval
46
- # -----------------------------
47
- def retrieve_context_from_web(query):
48
- """Use simple web API (DuckDuckGo Instant Answer) to retrieve context"""
49
- url = f"https://api.duckduckgo.com/?q={query}&format=json"
50
- try:
51
- res = requests.get(url, timeout=10)
52
- data = res.json()
53
- return data.get("AbstractText") or data.get("Heading") or "No context found."
54
- except Exception as e:
55
- return f"Retrieval failed: {e}"
56
-
57
- # -----------------------------
58
- # 4️⃣ Model inference with context
59
- # -----------------------------
60
- def generate_response(model, tokenizer, prompt, context=None, temperature=0.2, max_new_tokens=512):
61
- if context:
62
- full_input = f"Context: {context}\n\nQuestion: {prompt}\nAnswer:"
63
- else:
64
- full_input = f"Question: {prompt}\nAnswer:"
65
- inputs = tokenizer(full_input, return_tensors="pt").to(model.device)
66
- outputs = model.generate(
67
- **inputs,
68
- max_new_tokens=max_new_tokens,
69
- temperature=temperature,
70
- top_p=0.9,
71
- do_sample=True
72
- )
73
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
74
-
75
- # -----------------------------
76
- # 5️⃣ Pipeline function
77
- # -----------------------------
78
- def compliance_agent_pipeline(prompt, ticker=None, use_web=True):
79
- tokenizer, model = load_model()
80
- context = ""
81
-
82
- # Optional: integrate external tools
83
- if ticker:
84
- fin_data = fetch_financial_data(ticker)
85
- context += f"Market data for {ticker}: {json.dumps(fin_data, indent=2)}\n\n"
86
-
87
- if use_web:
88
- web_context = retrieve_context_from_web(prompt)
89
- context += f"Retrieved context: {web_context}\n\n"
90
-
91
- # Generate model output
92
- response = generate_response(model, tokenizer, prompt, context)
93
- return response
94
-
95
-
96
- # -----------------------------
97
- # 6️⃣ Example run
98
- # -----------------------------
99
- if __name__ == "__main__":
100
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
101
- print("🧠 FinGPT Task II - Compliance Agent Inference")
102
- query = input("Enter your compliance query: ")
103
-
104
- # Example: "Summarize the compliance risk of Meta in 2025"
105
- result = compliance_agent_pipeline(prompt=query, ticker="META", use_web=True)
106
-
107
- print("\n=== AI Compliance Agent Response ===\n")
108
- print(result)