ostock-backend / model /src /data /hierarchical_embedding.py
johnaness's picture
Deploy OStock FastAPI backend to HF Space (Docker SDK, port 7860)
4be2d4d
"""
Helper functions for hierarchical embedding processing (계층적 임베딩 처리 관련 함수).
"""
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
import yfinance as yf
def combine_stocks_for_embedding(all_data, tickers):
    """
    Stack per-ticker DataFrames into one long DataFrame for ticker embedding.

    Args:
        all_data: mapping of ticker -> pandas.DataFrame (presumably
            date-indexed price data; TODO confirm). Tickers listed in
            ``tickers`` but absent from this mapping are skipped.
        tickers: ordered iterable of ticker symbols to include.

    Returns:
        pandas.DataFrame whose first column is 'ticker', followed by the
        former index (turned into a column by ``reset_index``) and the
        original data columns. Rows are sorted by the original index; rows
        that share an index value keep the order given by ``tickers``.

    Raises:
        ValueError: if none of ``tickers`` is present in ``all_data``.
    """
    print("\n종목 임베딩을 위한 데이터 통합 중...")
    frames = []
    for ticker in tickers:
        if ticker not in all_data:
            continue
        # Copy so that tagging the frame does not mutate the caller's data.
        stock_df = all_data[ticker].copy()
        stock_df['ticker'] = ticker
        frames.append(stock_df)
        print(f"{ticker} 데이터: {len(stock_df)}행 처리")

    if not frames:
        raise ValueError("No data found in all_data for any of the requested tickers")

    # Stack vertically. A stable sort keeps rows with the same index value
    # (e.g. the same day for different tickers) in `tickers` order; the
    # default quicksort gives no such guarantee.
    combined_df = pd.concat(frames, axis=0).sort_index(kind='mergesort').reset_index()

    # Move the 'ticker' column to the front.
    cols = ['ticker'] + [c for c in combined_df.columns if c != 'ticker']
    return combined_df[cols]
def get_industry_data(tickers):
    """
    Fetch sector and industry labels for each ticker via yfinance.

    The lookup is best-effort. Any failure for a ticker is printed, and that
    ticker falls back to 'Unknown' for both fields, so one bad symbol does
    not abort the whole batch.

    Args:
        tickers: iterable of ticker symbols.

    Returns:
        pandas.DataFrame with columns ['ticker', 'sector', 'industry'], one
        row per distinct ticker. The columns are present even when
        ``tickers`` is empty.
    """
    industry_data = {}
    for ticker in tickers:
        try:
            info = yf.Ticker(ticker).info
            # `.get(key, default)` still yields None when the key exists with
            # a None value, so normalise falsy values to 'Unknown' too.
            record = {
                'sector': info.get('sector') or 'Unknown',
                'industry': info.get('industry') or 'Unknown',
            }
        except Exception as e:
            # Deliberate best-effort: report the failure and keep going.
            print(f"{ticker} 데이터 가져오기 실패: {str(e)}")
            record = {'sector': 'Unknown', 'industry': 'Unknown'}
        else:
            print(f"{ticker}: {record['sector']} - {record['industry']}")
        industry_data[ticker] = record

    # Passing `columns` guarantees the schema even for an empty input.
    df = pd.DataFrame.from_dict(industry_data, orient='index', columns=['sector', 'industry'])
    df.index.name = 'ticker'
    return df.reset_index()
def add_industry_encoding(data, industry_data):
    """
    Attach sector/industry labels to ``data`` and label-encode them.

    Args:
        data: pandas.DataFrame with a 'ticker' column.
        industry_data: pandas.DataFrame with 'ticker', 'sector' and
            'industry' columns (for example the output of
            ``get_industry_data``).

    Returns:
        (data, encoders):
            data: a new DataFrame with 'sector', 'industry', 'sector_id'
                and 'industry_id' columns added. Tickers without metadata
                get 'Unknown'.
            encoders: dict with the fitted 'sector_encoder' and
                'industry_encoder' (sklearn ``LabelEncoder``) and the class
                counts 'n_sectors' / 'n_industries'.
    """
    print("\n산업 정보 인코딩 중...")
    # Keep one metadata row per ticker. Duplicates on the right-hand side
    # of a left merge would silently multiply the rows of `data`.
    industry_data = industry_data.drop_duplicates(subset='ticker', keep='first')
    data = data.merge(industry_data, on='ticker', how='left')

    # Tickers missing from `industry_data` become 'Unknown' so the encoder
    # only ever sees strings.
    data['sector'] = data['sector'].fillna('Unknown')
    data['industry'] = data['industry'].fillna('Unknown')

    sector_encoder = LabelEncoder()
    industry_encoder = LabelEncoder()
    data['sector_id'] = sector_encoder.fit_transform(data['sector'])
    data['industry_id'] = industry_encoder.fit_transform(data['industry'])

    encoders = {
        'sector_encoder': sector_encoder,
        'industry_encoder': industry_encoder,
        'n_sectors': len(sector_encoder.classes_),
        'n_industries': len(industry_encoder.classes_)
    }
    print(f"섹터 수: {encoders['n_sectors']}, 산업 수: {encoders['n_industries']}")
    return data, encoders
def create_sector_industry_mapping(ticker_list, sector_industry_df, default=0):
    """
    Build per-ticker sector-id and industry-id lookup dicts.

    Args:
        ticker_list: iterable of ticker symbols to map.
        sector_industry_df: pandas.DataFrame indexed by ticker with integer
            'sector_id' and 'industry_id' columns. If a ticker label occurs
            more than once, its first row is used.
        default: id assigned to tickers absent from ``sector_industry_df``
            (default 0, matching the previous behaviour). NOTE(review): if
            the ids come from a LabelEncoder, 0 is also a real class, so
            consider passing a distinct value.

    Returns:
        (sector_mapping, industry_mapping): dicts mapping ticker -> int id.
    """
    # Deduplicate the index once. With repeated labels, `.loc[ticker, col]`
    # returns a Series instead of a scalar and `int()` fails on it.
    lookup = sector_industry_df[~sector_industry_df.index.duplicated(keep='first')]

    sector_mapping = {}
    industry_mapping = {}
    for ticker in ticker_list:
        if ticker in lookup.index:
            sector_mapping[ticker] = int(lookup.at[ticker, 'sector_id'])
            industry_mapping[ticker] = int(lookup.at[ticker, 'industry_id'])
        else:
            # No metadata for this ticker: fall back to the default id.
            sector_mapping[ticker] = default
            industry_mapping[ticker] = default
    return sector_mapping, industry_mapping
def apply_sector_industry_mapping(ticker_data, ticker_encoder, sector_mapping, industry_mapping):
    """
    Translate an array of encoded ticker ids into sector and industry ids.

    Args:
        ticker_data: array-like of encoded ticker ids, of any shape.
        ticker_encoder: object whose ``mapping`` attribute is a dict of
            ticker -> id. It is inverted here to recover the ticker for
            each id.
        sector_mapping: dict of ticker -> sector id.
        industry_mapping: dict of ticker -> industry id.

    Returns:
        (sector_data, industry_data): arrays with the same shape and dtype
        as ``ticker_data``. Ids not found in ``ticker_encoder.mapping``, and
        tickers missing from a mapping, map to 0.
    """
    ticker_data = np.asarray(ticker_data)
    # Encoded ticker id -> ticker symbol.
    id_to_ticker = {v: k for k, v in ticker_encoder.mapping.items()}

    # Resolve each distinct id once, then scatter the results back through
    # the inverse index instead of visiting every element in Python.
    unique_ids, inverse = np.unique(ticker_data, return_inverse=True)
    # NumPy 1.x returns a flat inverse; 2.x may return the input's shape.
    inverse = inverse.reshape(ticker_data.shape)
    unique_ids = unique_ids.tolist()

    def _map_ids(mapping):
        # Value for each distinct id (0 for unknown ids / unmapped tickers),
        # broadcast back to the input layout with the input's dtype.
        values = [
            mapping.get(id_to_ticker[tid], 0) if tid in id_to_ticker else 0
            for tid in unique_ids
        ]
        out = np.zeros_like(ticker_data)
        if values:
            out[...] = np.asarray(values)[inverse]
        return out

    sector_data = _map_ids(sector_mapping)
    industry_data = _map_ids(industry_mapping)
    print(f"적용된 섹터 클래스: {np.unique(sector_data)}")
    print(f"적용된 산업 클래스: {np.unique(industry_data)}")
    return sector_data, industry_data