# Model manager for financial sentiment models.
import torch
from typing import Dict, Any, Optional
from .models import load_bert_model, load_llama_model, BERTModel, LlamaModelWrapper
from .bias_analyzer import BiasAnalyzer
class ModelManager:
    """Manages loading and caching of financial sentiment models.

    Supported models are declared in ``model_configs``; each entry's
    ``"type"`` field ("bert" or "llama") selects the loading path in
    :meth:`load_model`. Loaded models are cached in ``loaded_models`` so
    repeated requests reuse the same wrapped model/tokenizer pair.
    """

    def __init__(self):
        # Cache: model_name -> (wrapped_model, tokenizer)
        self.loaded_models = {}
        # Registry of supported models.
        self.model_configs = {
            "FinBERT": {
                "model_id": "ProsusAI/finbert",
                "type": "bert",
            },
            "DeBERTa-v3": {
                "model_id": "mrm8488/deberta-v3-ft-financial-news-sentiment-analysis",
                "type": "bert",
            },
            "DistilRoBERTa": {
                "model_id": "mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis",
                "type": "bert",
            },
            # Large llama-based models, deliberately disabled:
            # "FinMA": {
            #     "model_id": "ChanceFocus/finma-7b-full",
            #     "tokenizer_id": "ChanceFocus/finma-7b-full",
            #     "type": "llama"
            # },
            # "FinGPT": {
            #     "model_id": "oliverwang15/FinGPT_v32_Llama2_Sentiment_Instruction_LoRA_FT",
            #     "tokenizer_id": "meta-llama/Llama-2-7b-chat-hf",
            #     "type": "llama"
            # }
        }
        # Token IDs representing each sentiment label for Llama-style models
        # (a list allows multiple surface forms of the same label word).
        # NOTE(review): assumed to match the Llama tokenizer vocabulary for
        # "Positive"/"Negative"/"Neutral" — confirm against the tokenizer.
        self.label_ids = {
            "Positive": [6374],
            "Negative": [8178, 22198],
            "Neutral": [21104],
        }

    def load_model(self, model_name: str) -> tuple:
        """Load (or fetch from cache) the named model.

        Args:
            model_name: Key into ``model_configs``.

        Returns:
            Tuple of ``(wrapped_model, tokenizer)``.

        Raises:
            ValueError: If ``model_name`` (or its configured type) is unknown.
            RuntimeError: If loading the underlying model fails.
        """
        if model_name in self.loaded_models:
            return self.loaded_models[model_name]

        # Fail fast with a clear message instead of a bare KeyError.
        try:
            config = self.model_configs[model_name]
        except KeyError:
            raise ValueError(
                f"Unknown model {model_name!r}; "
                f"available: {sorted(self.model_configs)}"
            ) from None

        try:
            if config["type"] == "bert":
                model, tokenizer = load_bert_model(config["model_id"])
                wrapped_model = BERTModel(model, tokenizer)
            elif config["type"] == "llama":
                model, tokenizer = load_llama_model(
                    base_tokenizer_id=config["tokenizer_id"],
                    model_id=config["model_id"],
                    cache_dir="./cache",
                )
                wrapped_model = LlamaModelWrapper(model, tokenizer, self.label_ids)
            else:
                # Previously an unknown type fell through to a NameError on
                # `wrapped_model`, which the broad except then masked.
                raise ValueError(f"Unsupported model type: {config['type']!r}")

            # Cache so subsequent calls reuse the same instances.
            self.loaded_models[model_name] = (wrapped_model, tokenizer)
            return wrapped_model, tokenizer
        except ValueError:
            # Configuration errors are surfaced as-is, not re-wrapped.
            raise
        except Exception as e:
            # Chain the original exception for debuggability. RuntimeError is
            # an Exception subclass, so callers catching Exception still work.
            raise RuntimeError(f"Failed to load {model_name}: {str(e)}") from e

    def get_bias_analyzer(self, model_name: str) -> "BiasAnalyzer":
        """Build a BiasAnalyzer wrapping the named (cached) model.

        Args:
            model_name: Key into ``model_configs``.

        Returns:
            A ``BiasAnalyzer`` configured for the model's type, using string
            splitting and a batch size of 16.
        """
        wrapped_model, tokenizer = self.load_model(model_name)
        return BiasAnalyzer(
            model=wrapped_model,
            tokenizer=tokenizer,
            model_type=self.model_configs[model_name]["type"],
            splitter_type='string',
            batch_size=16,
            is_wrapped=True,
        )