from typing import Any, Tuple
from .models import load_bert_model, load_llama_model, BERTModel, LlamaModelWrapper
from .bias_analyzer import BiasAnalyzer

class ModelManager:
    """Loads and caches financial sentiment-analysis models."""
    
    def __init__(self):
        self.loaded_models = {}
        self.model_configs = {
            "FinBERT": {
                "model_id": "ProsusAI/finbert",
                "type": "bert"
            },
            "DeBERTa-v3": {
                "model_id": "mrm8488/deberta-v3-ft-financial-news-sentiment-analysis",
                "type": "bert"
            },
            "DistilRoBERTa": {
                "model_id": "mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis",
                "type": "bert"
            },
            # "FinMA": {
            #     "model_id": "ChanceFocus/finma-7b-full",
            #     "tokenizer_id": "ChanceFocus/finma-7b-full",
            #     "type": "llama"
            # },
            # "FinGPT": {
            #     "model_id": "oliverwang15/FinGPT_v32_Llama2_Sentiment_Instruction_LoRA_FT",
            #     "tokenizer_id": "meta-llama/Llama-2-7b-chat-hf", 
            #     "type": "llama"
            # }
        }
        
        # Token IDs (Llama-2 tokenizer) for the sentiment label words;
        # passed to LlamaModelWrapper when a llama-type model is loaded
        self.label_ids = {
            "Positive": [6374],
            "Negative": [8178, 22198],
            "Neutral": [21104]
        }
    
    def load_model(self, model_name: str) -> Tuple[Any, Any]:
        """Load a model by name, caching the (wrapped_model, tokenizer) pair."""
        if model_name in self.loaded_models:
            return self.loaded_models[model_name]
        
        if model_name not in self.model_configs:
            raise KeyError(
                f"Unknown model '{model_name}'. "
                f"Available models: {list(self.model_configs)}"
            )
        
        config = self.model_configs[model_name]
        
        try:
            if config["type"] == "bert":
                model, tokenizer = load_bert_model(config["model_id"])
                wrapped_model = BERTModel(model, tokenizer)
                
            elif config["type"] == "llama":
                model, tokenizer = load_llama_model(
                    base_tokenizer_id=config["tokenizer_id"],
                    model_id=config["model_id"],
                    cache_dir="./cache"
                )
                wrapped_model = LlamaModelWrapper(model, tokenizer, self.label_ids)
            
            else:
                raise ValueError(f"Unsupported model type: {config['type']!r}")
            
            # Cache the loaded pair so later calls reuse it
            self.loaded_models[model_name] = (wrapped_model, tokenizer)
            return wrapped_model, tokenizer
            
        except Exception as e:
            raise RuntimeError(f"Failed to load {model_name}: {e}") from e
    
    def get_bias_analyzer(self, model_name: str) -> BiasAnalyzer:
        """Get a BiasAnalyzer for the specified model"""
        wrapped_model, tokenizer = self.load_model(model_name)
        
        # Create BiasAnalyzer with the wrapped model
        analyzer = BiasAnalyzer(
            model=wrapped_model,
            tokenizer=tokenizer,
            model_type=self.model_configs[model_name]["type"],
            splitter_type='string',
            batch_size=16,
            is_wrapped=True
        )
        
        return analyzer
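

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module's API. Because of the
    # relative imports above, this only runs as a module, e.g.
    # `python -m yourpackage.model_manager` (package name hypothetical).
    # Assumes "FinBERT" downloads from the Hugging Face Hub on first use.
    # BiasAnalyzer's analysis methods are not shown in this file, so only
    # construction is demonstrated here.
    manager = ModelManager()
    model, tokenizer = manager.load_model("FinBERT")  # cached after first call
    analyzer = manager.get_bias_analyzer("FinBERT")   # reuses the cached model
    print(f"Loaded {type(model).__name__} with {type(analyzer).__name__}")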