|
|
import gradio as gr |
|
|
import spaces |
|
|
import torch |
|
|
from peft import PeftModel |
|
|
import os |
|
|
|
|
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
from transformers.generation.utils import GenerationConfig |
|
|
|
|
|
# Optional Hugging Face access token, forwarded to every from_pretrained call
# below. When neither variable is set this is None, which keeps the library's
# default credential lookup.
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")

# Hub repository the tokenizer, weights and generation config are loaded from.
MODEL_ID = "Go4miii/DISC-FinLLM"

print("=" * 50)
print("开始加载FinGPT情感分析模型...")
print("=" * 50)

# Module-level handles used by analyze_sentiment; they stay None only if the
# load below fails (in which case the exception is re-raised anyway).
model = None
tokenizer = None
device = None

try:
    # trust_remote_code=True lets transformers run model/tokenizer code that
    # ships with the repo (presumably required for this checkpoint).
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        use_fast=False,
        trust_remote_code=True,
        token=hf_token,
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
        token=hf_token,
    )
    model.generation_config = GenerationConfig.from_pretrained(MODEL_ID, token=hf_token)
    # Record where the model's (first) parameters live so inputs can be moved
    # there; previously this stayed None and `.to(device)` did nothing.
    device = model.device
except Exception as e:
    print("\n" + "=" * 50)
    print("❌ 模型加载失败!")
    print(f"错误信息: {e}")
    print("=" * 50)
    # Fail the Space start-up loudly rather than serving a broken UI.
    raise
else:
    print(f"✅ 模型加载完成 (device: {device})")
|
|
|
|
|
|
|
|
# Hard cap (in tokens) on the prompt handed to the tokenizer.
_MAX_PROMPT_TOKENS = 512
# Token budget for the news text alone; the remainder of _MAX_PROMPT_TOKENS is
# left for the instruction and the trailing "Answer: " cue.
_MAX_NEWS_TOKENS = 448
# Tokens the model may generate; a sentiment label only needs a few.
_MAX_NEW_TOKENS = 32


def _truncate_news(news_text):
    """Return *news_text* clipped to at most ``_MAX_NEWS_TOKENS`` tokens.

    The news is truncated on its own (instead of truncating the assembled
    prompt) so the instruction and the final ``Answer: `` cue survive even
    for very long articles.
    """
    news_ids = tokenizer(news_text, add_special_tokens=False)["input_ids"]
    if len(news_ids) <= _MAX_NEWS_TOKENS:
        return news_text
    return tokenizer.decode(news_ids[:_MAX_NEWS_TOKENS], skip_special_tokens=True)


def _build_prompt(news_text):
    """Return the instruction-style sentiment prompt for *news_text*."""
    return (
        "Instruction: What is the sentiment of this news? "
        "Please choose an answer from {negative/neutral/positive}\n"
        f"Input: {news_text}\n"
        "Answer: "
    )


@spaces.GPU
def analyze_sentiment(news_text):
    """Analyze the sentiment of a piece of financial news.

    Args:
        news_text: Text from the Gradio textbox; may be None or empty.

    Returns:
        The first non-empty line the model generates (the prompt asks for
        negative/neutral/positive, but the text is returned as generated),
        or a user-facing warning/error message.
    """
    if model is None or tokenizer is None:
        return "❌ 模型未正确加载,请检查Spaces日志。"

    # Don't spend GPU time on empty / whitespace-only submissions.
    if not news_text or not news_text.strip():
        return "⚠️ 请输入金融新闻内容。"

    try:
        prompt = _build_prompt(_truncate_news(news_text.strip()))

        # A single prompt never needs padding (and padding=True raises when the
        # tokenizer defines no pad token). Truncation here is only a safety net
        # since the news itself has already been clipped.
        tokens = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=_MAX_PROMPT_TOKENS,
            truncation=True,
        ).to(model.device)

        with torch.no_grad():
            output_ids = model.generate(
                **tokens,
                # max_new_tokens rather than max_length: max_length counts the
                # prompt, so a long prompt would leave no room for the answer.
                max_new_tokens=_MAX_NEW_TOKENS,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode only the continuation so the prompt (which may itself contain
        # "Answer: ") can never leak into the result.
        prompt_len = tokens["input_ids"].shape[1]
        generated = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)

        # Keep the first non-empty line; the model may continue talking after it.
        return next((line.strip() for line in generated.splitlines() if line.strip()), "")

    except Exception as e:
        # UI-callback boundary: report the error to the user instead of failing
        # the request, but keep the full traceback in the Space logs.
        import traceback

        traceback.print_exc()
        return f"❌ 分析出错: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: a news textbox plus button on the left, the sentiment result on
# the right, and clickable example headlines underneath.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Description kept at column 0 inside the string so Markdown does not
    # treat the lines as an indented code block.
    gr.Markdown(
        """
# 📊 FinGPT 金融新闻情感分析

基于 **Go4miii/DISC-FinLLM** 模型的金融新闻情感分析工具。

输入金融新闻文本,AI将分析其情感倾向:**positive(积极)** / **neutral(中性)** / **negative(消极)**
"""
    )

    with gr.Row():
        with gr.Column():
            news_input = gr.Textbox(
                label="📰 输入金融新闻",
                placeholder="粘贴或输入金融新闻内容...",
                lines=6,
            )
            analyze_btn = gr.Button("🔍 分析情感", variant="primary", size="lg")

        with gr.Column():
            sentiment_output = gr.Textbox(
                label="😊 情感分析结果",
                lines=2,
            )

    gr.Examples(
        examples=[
            "FINANCING OF ASPOCOMP 'S GROWTH Aspocomp is aggressively pursuing its growth strategy by increasingly focusing on technologically more demanding HDI printed circuit boards PCBs.",
            "According to Gran, the company has no plans to move all production to Russia, although that is where the company is growing.",
            "Apple Inc. reported record quarterly revenue of $123.9 billion, up 11% year over year, and quarterly earnings per diluted share of $2.10.",
            "The Federal Reserve announced a 0.75 percentage point interest rate increase, the largest since 1994, to combat rising inflation.",
            "Tesla shares tumbled 12% after the company missed delivery expectations for the third consecutive quarter.",
            "Microsoft and OpenAI announced a multi-year partnership to develop advanced AI technologies, with Microsoft investing $10 billion.",
        ],
        inputs=news_input,
        label="📋 示例新闻(点击使用)",
    )

    # Both the button and pressing Enter in the textbox trigger the analysis.
    # (A leftover `clear.click(..., chatbot, ...)` referencing undefined
    # components was removed; it raised NameError at import time.)
    analyze_btn.click(
        fn=analyze_sentiment,
        inputs=news_input,
        outputs=sentiment_output,
    )
    news_input.submit(
        fn=analyze_sentiment,
        inputs=news_input,
        outputs=sentiment_output,
    )


if __name__ == "__main__":
    demo.launch()
|
|
|