"""
FinLoRA: Financial Large Language Models with LoRA Adaptation
Main inference script for Hugging Face submission

This script provides easy loading and inference for all FinLoRA models.
"""

import os
import sys
import warnings
from pathlib import Path
from typing import List

import torch

# Silence noisy third-party warnings emitted during model loading.
warnings.filterwarnings("ignore")

try:
    from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
    from peft import PeftModel
except ImportError as e:
    print(f"Missing required dependencies: {e}")
    print("Please install: pip install transformers peft bitsandbytes")
    sys.exit(1)
|
| class FinLoRAPredictor: |
| """Main FinLoRA predictor class""" |
| |
| def __init__(self, |
| model_name: str = "sentiment_llama_3_1_8b_8bits_r8", |
| base_model: str = "meta-llama/Llama-3.1-8B-Instruct", |
| use_4bit: bool = False): |
| """ |
| Initialize FinLoRA predictor |
| |
| Args: |
| model_name: Name of the LoRA model to load |
| base_model: Base model name |
| use_4bit: Whether to use 4-bit quantized models |
| """ |
| self.model_name = model_name |
| self.base_model = base_model |
| self.use_4bit = use_4bit |
| |
| |
| self.device = "cuda" if torch.cuda.is_available() else "cpu" |
| print(f"Using device: {self.device}") |
| |
| |
| self.model = None |
| self.tokenizer = None |
| |
| |
| self._load_model() |
| |

    def _load_model(self):
        """Load the base model and attach the FinLoRA adapter."""
        try:
            print(f"Loading model: {self.model_name}")

            self.tokenizer = AutoTokenizer.from_pretrained(self.base_model)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            if self.device == "cuda":
                if self.use_4bit:
                    bnb_config = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_use_double_quant=True,
                        bnb_4bit_quant_type="nf4",
                        bnb_4bit_compute_dtype=torch.bfloat16,
                    )
                    # Keep non-quantized weights in the same dtype as the
                    # 4-bit compute dtype.
                    torch_dtype = torch.bfloat16
                else:
                    bnb_config = BitsAndBytesConfig(
                        load_in_8bit=True,
                        llm_int8_threshold=6.0,
                    )
                    torch_dtype = torch.float16

                base_model = AutoModelForCausalLM.from_pretrained(
                    self.base_model,
                    quantization_config=bnb_config,
                    device_map="auto",
                    torch_dtype=torch_dtype,
                    trust_remote_code=True,
                )
            else:
                # CPU fallback: full precision, no quantization.
                base_model = AutoModelForCausalLM.from_pretrained(
                    self.base_model,
                    device_map="cpu",
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=True,
                )

            model_dir = "models_4bit" if self.use_4bit else "models"
            model_path = f"{model_dir}/{self.model_name}"

            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model path not found: {model_path}")

            self.model = PeftModel.from_pretrained(base_model, model_path)
            self.model.eval()

            print(f"Model loaded successfully: {self.model_name}")

        except Exception as e:
            print(f"Error loading model: {e}")
            raise
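
    # Note: PeftModel.from_pretrained only attaches the LoRA weights on top of
    # the frozen base; nothing is merged. With a non-quantized base you could
    # merge the adapter for slightly faster inference (a sketch, not used here
    # and not applicable to the 4-bit/8-bit paths above):
    #
    #     merged_model = self.model.merge_and_unload()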

    def predict(self,
                text: str,
                max_length: int = 256,
                temperature: float = 0.7) -> str:
        """
        Generate a prediction for the given text.

        Args:
            text: Input text
            max_length: Maximum number of new tokens to generate
            temperature: Sampling temperature
        """
        try:
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            )

            if self.device == "cuda":
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_length,
                    do_sample=True,
                    temperature=temperature,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            # Decode only the newly generated tokens, so the prompt never has
            # to be stripped back out of the response by string matching.
            prompt_length = inputs["input_ids"].shape[1]
            response = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            )

            return response.strip()

        except Exception as e:
            print(f"Prediction error: {e}")
            return f"Error: {str(e)}"

    def classify_sentiment(self, text: str) -> str:
        """Classify financial sentiment."""
        prompt = f"Classify the sentiment of this financial text as positive, negative, or neutral:\n\nText: {text}\n\nSentiment:"
        response = self.predict(prompt, max_length=10)

        if 'positive' in response.lower():
            return "positive"
        elif 'negative' in response.lower():
            return "negative"
        else:
            return "neutral"

    def extract_entities(self, text: str) -> str:
        """Extract financial entities."""
        prompt = f"Extract financial entities from the following text:\n\nText: {text}\n\nEntities:"
        return self.predict(prompt, max_length=100)

    def classify_headline(self, headline: str) -> str:
        """Classify a financial headline."""
        prompt = f"Classify this financial headline as positive or negative:\n\nHeadline: {headline}\n\nSentiment:"
        response = self.predict(prompt, max_length=10)

        if 'positive' in response.lower() or 'yes' in response.lower():
            return "positive"
        else:
            return "negative"

    def extract_xbrl_tags(self, text: str) -> str:
        """Extract XBRL tags from financial text."""
        prompt = f"Extract XBRL tags from the following financial statement:\n\nStatement: {text}\n\nXBRL Tags:"
        return self.predict(prompt, max_length=100)

    def process_financial_text(self, text: str) -> str:
        """Analyze general financial text."""
        prompt = f"Analyze this financial text and provide insights:\n\nText: {text}\n\nAnalysis:"
        return self.predict(prompt, max_length=200)


def list_available_models(use_4bit: bool = False) -> List[str]:
    """List all available LoRA adapters on disk."""
    model_dir = "models_4bit" if use_4bit else "models"
    models_path = Path(model_dir)

    if not models_path.exists():
        return []

    models = []
    # Use a distinct loop variable so it does not shadow `model_dir` above.
    for adapter_dir in models_path.iterdir():
        if adapter_dir.is_dir() and (adapter_dir / "adapter_config.json").exists():
            models.append(adapter_dir.name)

    return sorted(models)
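
# Expected on-disk layout, inferred from the loading logic above (the adapter
# weights file name is the standard PEFT output and is an assumption here):
#
#     models/
#         sentiment_llama_3_1_8b_8bits_r8/
#             adapter_config.json
#             adapter_model.safetensors
#     models_4bit/
#         <4-bit adapters, same structure>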


def main():
    """Main function for testing the model."""
    print("=== FinLoRA Financial Language Model ===")
    print("Loading model and testing inference...")

    available_models_8bit = list_available_models(use_4bit=False)
    available_models_4bit = list_available_models(use_4bit=True)

    print(f"Available 8-bit models: {', '.join(available_models_8bit)}")
    print(f"Available 4-bit models: {', '.join(available_models_4bit)}")

    if not available_models_8bit and not available_models_4bit:
        print("No models found in 'models' or 'models_4bit' directories")
        return

    if available_models_8bit:
        model_name = available_models_8bit[0]
        use_4bit = False
    else:
        model_name = available_models_4bit[0]
        use_4bit = True

    print(f"Loading model: {model_name} ({'4-bit' if use_4bit else '8-bit'})")

    try:
        predictor = FinLoRAPredictor(
            model_name=model_name,
            use_4bit=use_4bit,
        )

        test_cases = [
            {
                "task": "Sentiment Analysis",
                "text": "The company's quarterly earnings exceeded expectations by 20%.",
                "method": predictor.classify_sentiment,
            },
            {
                "task": "Entity Extraction",
                "text": "Apple Inc. reported revenue of $394.3 billion in 2022.",
                "method": predictor.extract_entities,
            },
            {
                "task": "Headline Classification",
                "text": "Federal Reserve announces interest rate cut",
                "method": predictor.classify_headline,
            },
            {
                "task": "XBRL Tag Extraction",
                "text": "Total assets: $1,234,567,890. Current assets: $456,789,123.",
                "method": predictor.extract_xbrl_tags,
            },
        ]

        for i, test_case in enumerate(test_cases, 1):
            print(f"\n--- Test {i}: {test_case['task']} ---")
            print(f"Input: {test_case['text']}")

            try:
                result = test_case['method'](test_case['text'])
                print(f"Output: {result}")
            except Exception as e:
                print(f"Error: {e}")

        print("\nModel testing completed successfully!")

    except Exception as e:
        print(f"Error: {e}")
        print("\nTroubleshooting:")
        print("1. Ensure all model files are in the 'models' or 'models_4bit' directory")
        print("2. Check that the base model can be downloaded")
        print("3. Verify CUDA availability if using GPU")


if __name__ == "__main__":
    main()