import argparse import csv import gradio as gr import json import os import pandas as pd import requests from datetime import datetime from io import StringIO region_choices = ['us', 'ko'] user_portfolio_output_path = "used_user.csv" json_example_output_path = "output_example.json" json_new_output_path = "output.json" report_example_output_path = "report_example.txt" report_new_output_path = "report.txt" def make_user_csv(csv_text, top_n): """ 사용자 포트폴리오를 top_n만 user_portfolio_output_path에 저장""" df = load_portfolio_to_df(csv_text) df['매입주가'] = df['매입주가'].astype(float) df['보유주식수'] = df['보유주식수'].astype(float) df['매입금액'] = df['매입주가'] * df['보유주식수'] df_sorted = df.sort_values(by='매입금액', ascending=False) df_top = df_sorted.head(top_n) df_top = df_top[['종목코드', '종목명', '매입주가', '보유주식수']] df_top.to_csv(user_portfolio_output_path, index=False, encoding='utf-8-sig') return df_top # 반환값 DataFrame def load_portfolio_to_df(csv_text): """csv_text를 pandas DataFrame으로 변환하여 반환""" if isinstance(csv_text, str) and os.path.exists(csv_text): return pd.read_csv(csv_text, dtype=str) if isinstance(csv_text, str): try: return pd.read_csv(StringIO(csv_text), dtype=str) except Exception as e: raise ValueError(f"CSV 문자열을 DataFrame으로 변환 실패: {e}") if isinstance(csv_text, pd.DataFrame): return csv_text.copy() raise ValueError("csv_text는 CSV 파일 경로, CSV 문자열 또는 DataFrame이어야 합니다.") def parse_csv(csv_text): """CSV 텍스트 → 종목코드 리스트""" try: df = pd.read_csv(StringIO(csv_text)) except Exception: return {"error": "csv_parse_failed", "detail": "CSV 형식이 잘못되었습니다."} if "종목코드" not in df.columns: return {"error": "missing_column", "detail": "'종목코드' 컬럼이 없습니다."} stock_list = df["종목코드"].astype(str).tolist() return stock_list def request_post(base_url: str, service: str, req_body: dict): try: url = f'{base_url}/{service}' res = requests.post( url=url, json=req_body ) res.raise_for_status() except Exception as e: print(f'[ERROR] {service} / 에러 발생: {str(e)}') res_body = res.json() if res_body['success']: return res_body['data'] def remove_duplicates(data_list): """stock 키값 기준으로 중복 제거""" seen = set() unique_list = [] for item in data_list: stock_name = item.get('stock') if stock_name and stock_name not in seen: seen.add(stock_name) unique_list.append(item) return unique_list def fetch_example_data(): """json_example_output_path 경로의 JSON 파일을 읽어 dict 형태로 반환""" if not os.path.exists(json_example_output_path): print(f"[ERROR] 파일이 존재하지 않습니다: {json_example_output_path}") return None try: user_name = "이서준" region = 'us' user_csv = """종목코드,종목명,매입주가,보유주식수 NVDA,NVIDIA Corporation,159.69,167 AMZN,"Amazon.com, Inc.",211.58,140 MSFT,Microsoft Corporation,527.99,131 AAPL,Apple Inc.,258.45,116 GOOGL,Alphabet Inc.,262.59,97 """ processed_csv = make_user_csv(user_csv, top_n=5) processed_csv.to_csv(user_portfolio_output_path, index=False, encoding="utf-8") with open(user_portfolio_output_path, "r", encoding="utf-8") as f: csv_data = f.read() with open(json_example_output_path, "r", encoding="utf-8") as f: json_data = json.load(f) return user_name, region, csv_data, json_data except Exception as e: print(f"[ERROR] process_user_data 오류: {e}") return "", [], "", {} def update_user_with_stock_price(json_data, res_data): """ res_data 로 받은 stock_price 정보를 기반으로 json_data["user"] 항목에 현재주가, 수익률, 현재자산을 업데이트 """ def normalize(code): return str(code).strip().upper() # user 목록 가져오기 user_list = json_data.get("user", []) if not isinstance(user_list, list): print("[WARN] user 데이터 구조가 list 아님 → 업데이트 불가") return # 종목 가격 dict(normalized) stock_price_map = {normalize(k): v for k, v in res_data.items()} # user 각 항목 업데이트 for user_item in user_list: raw_code = user_item.get("종목코드") if not raw_code: continue code = normalize(raw_code) # stock_price 에 해당 종목이 존재하는지 확인 prices = stock_price_map.get(code) if not prices: print(f"[WARN] {code} 종목의 주가 데이터 없음 → user 추가정보 스킵") continue # 최신 종가 (list 마지막) last_row = prices[-1] try: current_price = float(last_row.get("Close")) buy_price = float(user_item.get("매입주가")) qty = float(user_item.get("보유주식수")) except Exception: print(f"[WARN] {code} 값 변환 오류 → user 추가정보 스킵") continue # 수익률 계산 profit_rate = round(((current_price - buy_price) / buy_price) * 100, 2) current_asset = round(current_price * qty, 2) # 업데이트 user_item["현재주가"] = str(current_price) user_item["수익률"] = f"{profit_rate}%" user_item["현재자산"] = str(current_asset) # 저장 json_data["user"] = user_list return json_data def fetch_new_data(region, csv_text, server_url): """API 실행: 유사 투자자 데이터 요청 후 JSON 반환""" try: final_json = {} portfolio = make_user_csv(csv_text, top_n=5) # user_portfolio_output_path 저장 # =========================== # 1. 유사 투자회사 API # =========================== service = 'similar_investors' req_body = { 'csv_text': portfolio.to_csv(index=False), 'region': region[0] if isinstance(region, list) else region, 'top': 5 } print(f"[INFO] 유사 투자자 {req_body['top']}개 수집 시작..") res_data = request_post(server_url, service, req_body) if not res_data: print(f'[ERROR] 유사 투자자 수집 실패..') return None print(f'[INFO] 유사 투자자 수집 완료:', res_data) final_json[service] = res_data # user 키 저장 used_user_csv = portfolio.to_csv(index=False) f = StringIO(used_user_csv) reader = csv.reader(f) rows = list(reader) if not rows or len(rows) < 2: print("[WARN] portfolio CSV 내용이 비어 있어 user 데이터 저장을 생략합니다.") final_json["user"] = [] else: headers = rows[0] user_list = [dict(zip(headers, cols)) for cols in rows[1:]] final_json["user"] = user_list # 산업, 테마 요청시 투자자 종목도 포함 si_names = list(set([d['NAME'] for rows in res_data.values() for d in rows])) si_companies = list(set([d['COMPANY'] for rows in res_data.values() for d in rows])) # =========================== # 2. 유사 투자회사 설명 API # =========================== print("[INFO] 투자회사 name 목록:", si_names) final_json["investment_company"] = {} service = "investment_company" for name in si_names: print(f"[INFO] {name} 설명 수집 시작..") req_body = {"name": name} res_data = request_post(server_url, service, req_body) if not res_data: print("[ERROR] 설명 수집 실패:", name) final_json[service][name] = { "error": True, "detail": "request_post failed" } continue print("[INFO] 설명 수집 완료:", res_data) final_json[service][name] = res_data if not final_json[service]: print(f'[ERROR] 유사 투자자 수집 실패..') return None # =========================== # 3. 산업군 API # =========================== service = 'industry_info' stock = portfolio['종목코드'].to_list() + si_companies req_body = { 'stock': stock } print(f'[INFO] 산업정보 {",".join(stock)} 수집 시작..') res_data = request_post(server_url, service, req_body) if not res_data: print(f'[ERROR] 산업정보 수집 실패..') return None print(f'[INFO] 산업정보 수집 완료:', res_data) res_data = remove_duplicates(res_data) final_json[service] = res_data # =========================== # 4. 테마 API # =========================== service = 'theme_info' stock = portfolio['종목코드'].to_list() + si_companies req_body = { 'stock': stock } print(f'[INFO] 테마정보 {",".join(stock)} 수집 시작..') res_data = request_post(server_url, service, req_body) if not res_data: print(f'[ERROR] 테마정보 수집 실패..') return None print(f'[INFO] 테마정보 수집 완료:', res_data) res_data = remove_duplicates(res_data) final_json[service] = res_data # =========================== # 5. 주가 API # =========================== service = 'stock_price' stock = ','.join(portfolio['종목코드']) req_body = {'stock': stock} print(f'[INFO] 종목가격 {stock} 수집 시작..') res_data = request_post(server_url, service, req_body) if not res_data: print(f'[ERROR] 종목가격 수집 실패..') return None print(f'[INFO] 종목가격 수집 완료:', res_data) final_json[service] = res_data update_user_with_stock_price(final_json, res_data) # =========================== # 6. 뉴스 API # =========================== service = 'stock_news' for stock in portfolio['종목코드']: print(f'[INFO] 종목뉴스 {stock} 수집 시작..') req_body = { 'stock': stock, 'period': 7 } res_data = request_post(server_url, service, req_body) if res_data: print(f'[INFO] 종목뉴스 {stock} 수집 완료:', res_data) final_json.setdefault(service, {}) if isinstance(res_data, dict): final_json[service].update(res_data) else: final_json[service][stock] = res_data else: print(f'[ERROR] 종목뉴스 {stock} 수집 실패..') if not final_json.get(service): print(f'[ERROR] 종목뉴스 수집 실패..') return None try: with open(json_new_output_path, "w", encoding="utf-8") as f: json.dump(final_json, f, ensure_ascii=False, indent=4) print(f"[INFO] {json_new_output_path}에 저장 완료") except Exception as e: print(f"[ERROR] {json_new_output_path} 저장 오류:", str(e)) return final_json except Exception as e: print("[ERROR] 실행 중 오류:", str(e)) return {"error": "exception", "detail": str(e)} def load_report(): """report_example_output_path 경로의 txt 파일을 읽어 문자열로 반환""" if not os.path.exists(report_example_output_path): print(f"[ERROR] 파일이 존재하지 않습니다: {report_example_output_path}") return "" try: with open(report_example_output_path, "r", encoding="utf-8") as f: data = f.read() return data except Exception as e: print(f"[ERROR] 파일 읽기 중 오류 발생: {e}") return "" def generate_report(user_name, server_url, json): """최종 리포트 생성 API 호출 → user_name, csv, json 입력""" try: print(f"[INFO] {user_name} 님을 위한 포트폴리오 분석 리포트 생성") with open(user_portfolio_output_path, "r", encoding="utf-8") as f: csv_text = f.read() payload = { "date": datetime.now().strftime("%Y-%m-%d"), # 리포트 생성 날짜 (기본: 오늘 날짜) "user_name": user_name, "csv_text": csv_text, "json_text": json } print("[INFO] 입력데이터 로드 완료") print("[INFO] 리포트 생성 중 ... ") resp = requests.post(f"{server_url}/report", json=payload) if resp.status_code != 200: return f"❌ 리포트 생성 실패: {resp.text}" data = resp.json() report_md = data["data"]["report"] print("[INFO] 리포트 생성 완료") with open(report_new_output_path, "w", encoding="utf-8") as f: f.write(report_md) print(f"[INFO] {report_new_output_path}에 저장 완료\n") return report_md except FileNotFoundError: return f"❌ 오류: {user_portfolio_output_path} 파일을 찾을 수 없습니다." except Exception as e: return f"❌ 오류 발생: {str(e)}" def enforce_single(selection): """ selection: 현재 체크된 리스트 (None or list) 반환값: 반드시 1개만 담긴 리스트 """ # 사용자 입력이 None 이거나 비어있으면 기본값 반환 if not selection: return [region_choices[0]] # 여러개 선택된 경우 마지막 선택한 항목만 남김 if len(selection) > 1: return [selection[-1]] # 이미 1개면 그대로 반환 return selection def main(): parser = argparse.ArgumentParser(description="Web Client") parser.add_argument("--port", default=7864, type=int, help="Port number to run the Gradio app") parser.add_argument("--server-url", default="http://localhost:8080" , type=str, help="Server url to run the app") args = parser.parse_args() with gr.Blocks() as app: gr.Markdown("# 포트폴리오 분석 리포트 📈 ") server_url = gr.State(args.server_url) with gr.Row(scale=1): with gr.Column(scale=1): gr.Markdown("### 1. 사용자 이름을 입력하세요.") user_name = gr.Textbox(value="이서준", placeholder="예: 이서준 ", show_label=False) with gr.Column(scale=1): gr.Markdown("") with gr.Row(scale=1): with gr.Column(scale=1): gr.Markdown("### 2. 사용자 포트폴리오를 입력하세요.") region_checkbox = gr.CheckboxGroup(choices=region_choices, value=[region_choices[0]], label="당신의 포트폴리오에 가장 부합하는 지역을 선택하세요.") region_checkbox.change(fn=enforce_single, inputs=region_checkbox, outputs=region_checkbox) csv_input = gr.Textbox( label="당신을 포트폴리오를 입력하세요. (CSV형식) ", lines=7, max_lines=7, value="종목코드,종목명,매입주가,보유주식수\n" "TSLA,Tesla Inc.,227.34,88\n" "META,Meta Platforms Inc.,484.12,52\n" "NFLX,Netflix Inc.,598.43,31\n" "AVGO,Broadcom Inc.,1421.77,12\n" "CRM,Salesforce Inc.,289.24,63", placeholder="예: \n" "종목코드,종목명,매입주가,보유주식수\n" "TSLA,Tesla Inc.,227.34,88\n" "META,Meta Platforms Inc.,484.12,52\n" "NFLX,Netflix Inc.,598.43,31\n" "AVGO,Broadcom Inc.,1421.77,12\n" "CRM,Salesforce Inc.,289.24,63", interactive=True, show_label=True ) with gr.Column(scale=1): gr.Markdown("") with gr.Row(scale=1): with gr.Column(scale=1): gr.Markdown("### 3. 입력 데이터를 가져오세요.") with gr.Row(scale=1): fetch_example_button = gr.Button("기존 입력 데이터 가져오기") fetch_new_button = gr.Button("새로운 입력 데이터 생성하기") with gr.Column(scale=1): gr.Markdown("### 4. 포트폴리오 분석 리포트를 생성하세요.") with gr.Row(scale=1): report_example_button = gr.Button("기존 리포트 가져오기") report_new_button = gr.Button("새로운 리포트 생성하기") with gr.Row(scale=1): json_output = gr.JSON(label="API 응답(JSON)") report_output = gr.Markdown() fetch_example_button.click( fn=fetch_example_data, inputs=[], outputs=[user_name, region_checkbox, csv_input, json_output] ) fetch_new_button.click( fn=fetch_new_data, inputs=[region_checkbox, csv_input, server_url], outputs=[json_output] ) report_example_button.click( fn=load_report, inputs=[], outputs=[report_output] ) report_new_button.click( fn=generate_report, inputs=[user_name, server_url, json_output], outputs=[report_output] ) app.launch(theme='CultriX/gradio-theme', share=True, server_name="0.0.0.0", server_port=args.port, debug=True) if __name__ == "__main__": main()