# Ticker / load_model.py
# Author: rajaatif786 — last commit 7f3e321 (verified): "Update load_model.py".
import os
import os
import requests
def download_model_if_not_exists(model_url, destination_dir, filename="entity_model2.pt", timeout=60):
    """Download the model file from ``model_url`` unless it already exists locally.

    The body is streamed to ``<destination_dir>/<filename>.part`` and only
    renamed to its final name once it has been written completely. An
    interrupted or failed download therefore never leaves a truncated file at
    the final path, which later calls would otherwise mistake for a finished
    model and skip re-downloading.

    Args:
        model_url (str): Direct URL of the model file to download.
        destination_dir (str): Directory to save the model in. Created
            (including parents) if missing; an empty string means the
            current working directory.
        filename (str): Name of the model file inside ``destination_dir``.
            Defaults to ``"entity_model2.pt"``.
        timeout (float | tuple | None): Passed to ``requests.get``: seconds to
            wait for the connection and for each read. ``None`` waits forever.

    Returns:
        bool: True if the model was downloaded. False if it already existed,
        or if the download failed with a ``requests`` error (the error is
        printed).

    Raises:
        OSError: If the file cannot be written or renamed. Any partial file
            is removed first.
    """
    # An empty destination means "current directory"; os.makedirs("") would raise.
    if destination_dir and not os.path.exists(destination_dir):
        print(f"Model directory '{destination_dir}' not found. Creating it.")
        # exist_ok guards against another process creating it in between.
        os.makedirs(destination_dir, exist_ok=True)

    model_path = os.path.join(destination_dir, filename)
    if os.path.exists(model_path):
        print(f"Model already exists at '{model_path}'. Skipping download.")
        return False

    print(f"Downloading model from '{model_url}' to '{model_path}'...")
    partial_path = model_path + ".part"
    try:
        # `with` releases the streamed connection even if writing fails.
        with requests.get(model_url, stream=True, timeout=timeout) as response:
            response.raise_for_status()  # Raise an exception for non-2xx status codes
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=64 * 1024):
                    if chunk:  # skip empty keep-alive chunks
                        f.write(chunk)
        # Only a fully written file ever appears under the final name.
        os.replace(partial_path, model_path)
    except requests.exceptions.RequestException as e:
        print(f"Error downloading model: {e}")
        return False
    finally:
        # On any failure before the rename, drop the partial file so the next
        # call retries instead of skipping. After a successful rename it is gone.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("Download complete.")
    return True
# Direct-download URL for the model weights. The repository page URL
# (https://huggingface.co/rajaatif786/TickerExtraction) is a web page, not the
# file itself; download_model_if_not_exists would save that page's body as
# entity_model2.pt. Hugging Face serves raw repository files at
# https://huggingface.co/<repo_id>/resolve/<revision>/<filename>.
# NOTE(review): assumes entity_model2.pt sits at the repo root on `main` — confirm.
model_url = "https://huggingface.co/rajaatif786/TickerExtraction/resolve/main/entity_model2.pt"
model_dir = "./TickerExtraction"  # Change this if needed
downloaded = download_model_if_not_exists(model_url, model_dir)

import pandas as pd
import numpy as np

# Debug aid: list the current working directory's contents.
print(os.listdir())

from EntityExtractor import EntityDataset, EntityBertNet, BertEntityExtractor, LABEL_MAP
import nltk

nltk.download('stopwords')  # fetch the NLTK stopwords corpus
# NOTE(review): where load_trained_model() reads its weights from is defined in
# EntityExtractor, not here — confirm it matches model_dir/entity_model2.pt.
entity_extractor = BertEntityExtractor.load_trained_model()