File size: 1,879 Bytes
bf7f566
6248d60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6166234
bf7f566
 
 
7f3e321
defe2fb
bf7f566
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os
import os
import requests

def download_model_if_not_exists(model_url, destination_dir, model_filename="entity_model2.pt"):
    """Download the model from *model_url* into *destination_dir* unless it already exists.

    Args:
        model_url (str): The URL of the model to download.
        destination_dir (str): The directory to download the model to.
        model_filename (str): File name to store the model under. Defaults to
            "entity_model2.pt" for backward compatibility with the original
            hard-coded name.

    Returns:
        bool: True if the model was downloaded, False if it already existed
        or the download failed.
    """

    if not os.path.exists(destination_dir):
        print(f"Model directory '{destination_dir}' not found. Creating it.")
    # exist_ok=True avoids a crash if the directory appears between the
    # check above and this call (check-then-act race).
    os.makedirs(destination_dir, exist_ok=True)

    model_path = os.path.join(destination_dir, model_filename)
    if os.path.exists(model_path):
        print(f"Model already exists at '{model_path}'. Skipping download.")
        return False

    print(f"Downloading model from '{model_url}' to '{model_path}'...")
    try:
        # timeout prevents the script from hanging forever on an
        # unresponsive server; stream=True downloads in chunks.
        response = requests.get(model_url, stream=True, timeout=60)
        response.raise_for_status()  # Raise an exception for non-2xx status codes
        with open(model_path, 'wb') as f:
            for chunk in response.iter_content(8192):
                f.write(chunk)
        print("Download complete.")
        return True
    except requests.exceptions.RequestException as e:
        print(f"Error downloading model: {e}")
        # Remove any partially written file; otherwise the next run would see
        # it on disk and skip the download, leaving a corrupt model in place.
        if os.path.exists(model_path):
            os.remove(model_path)
        return False

# NOTE(review): this URL is the Hugging Face *repository page*, not a direct
# file link — requests.get() on it returns HTML, not model weights. A direct
# download link would normally look like
# .../TickerExtraction/resolve/main/<filename>; confirm before relying on
# the downloaded "entity_model2.pt" being a real model file.
model_url = "https://huggingface.co/rajaatif786/TickerExtraction"
model_dir = "./TickerExtraction"  # Change this if needed

# True only when a fresh download succeeded; False if the file already
# existed locally or the download failed.
downloaded = download_model_if_not_exists(model_url, model_dir)



import pandas as pd
import numpy as np
#os.chdir("./TickerExtraction")
print(os.listdir())  # sanity check: show CWD contents before the local import below

# EntityExtractor is a project-local module — presumably the one inside
# ./TickerExtraction downloaded above; this import fails unless that module
# is importable from the current working directory / sys.path (note the
# os.chdir above is commented out — TODO confirm intended CWD).
from EntityExtractor import EntityDataset, EntityBertNet,BertEntityExtractor, LABEL_MAP
import nltk
nltk.download('stopwords')  # fetch the NLTK stopwords corpus (network call on first use)

# Load the trained model; load_trained_model takes no arguments here, so it
# presumably reads weights from a default path — verify in EntityExtractor.
entity_extractor = BertEntityExtractor.load_trained_model()