wenbemi committed on
Commit
d8952e9
Β·
verified Β·
1 Parent(s): a6c552a

Update chat_a.py

Browse files
Files changed (1) hide show
  1. chat_a.py +71 -61
chat_a.py CHANGED
@@ -1,33 +1,30 @@
1
- #!/usr/bin/env python
2
- # coding: utf-8
3
-
4
- # In[10]:
5
- import os, io, pathlib
6
- from huggingface_hub import hf_hub_download
7
  import pandas as pd
 
8
  import torch
9
- from sentence_transformers import SentenceTransformer, util
10
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
11
  import torch.nn.functional as F
12
  from collections import defaultdict
13
  from datetime import datetime
14
- import random
15
- import re
 
 
16
 
17
- CACHE_DIR = os.getenv("TRANSFORMERS_CACHE", "/tmp/hf-cache")
 
18
  os.makedirs(CACHE_DIR, exist_ok=True)
19
 
20
- HF_DATASET_REPO = os.getenv("HF_DATASET_REPO", "emisdfde/moai-travel-data") # ← 본인 리포
21
  HF_DATASET_REV = os.getenv("HF_DATASET_REV", "main")
22
 
23
  def _is_pointer_bytes(b: bytes) -> bool:
24
  head = b[:2048].decode(errors="ignore").lower()
25
- # git-lfs / xet 포인터 ν…μŠ€νŠΈ νŒ¨ν„΄ λͺ¨λ‘ 감지
26
  return (
27
- "version https://git-lfs.github.com/spec/v1" in head or
28
- "git-lfs" in head or
29
- "xet" in head or # e.g. "Xet backed hash"
30
- "pointer size" in head
31
  )
32
 
33
  def _read_csv_bytes(b: bytes) -> pd.DataFrame:
@@ -40,7 +37,7 @@ def load_csv_smart(local_path: str,
40
  hub_filename: str | None = None,
41
  repo_id: str = HF_DATASET_REPO,
42
  repo_type: str = "dataset",
43
- revision: str = HF_DATASET_REV):
44
  if hub_filename is None:
45
  hub_filename = os.path.basename(local_path)
46
  if os.path.exists(local_path):
@@ -58,56 +55,69 @@ def load_csv_smart(local_path: str,
58
  except UnicodeDecodeError:
59
  return pd.read_csv(cached, encoding="cp949")
60
 
 
 
 
 
61
 
62
- from css import log_and_render
63
- import streamlit as st, pandas as pd, json, requests
64
- # -------------------- λͺ¨λΈ 및 데이터 λ‘œλ”© --------------------
65
- # λͺ¨λΈ λ‘œλ”© 뢀뢄을 ν•¨μˆ˜λ‘œ λ§Œλ“€κ³  λ°μ½”λ ˆμ΄ν„° μΆ”κ°€
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  @st.cache_resource(show_spinner=False)
67
- def load_sentiment_model():
68
- repo = "hun3359/klue-bert-base-sentiment"
69
- tok = AutoTokenizer.from_pretrained(repo, cache_dir=CACHE_DIR)
70
- model = AutoModelForSequenceClassification.from_pretrained(repo, cache_dir=CACHE_DIR)
71
- return tok, model
72
-
73
- @st.cache_resource
74
- def load_sbert_model():
75
- print("SBERT λͺ¨λΈ λ‘œλ”© 쀑... (이 λ©”μ‹œμ§€λŠ” ν•œ 번만 보여야 ν•©λ‹ˆλ‹€)")
76
- return SentenceTransformer("jhgan/ko-sroberta-multitask")
77
 
78
- @st.cache_resource
79
  def load_sentiment_model():
80
- print("감성 뢄석 λͺ¨λΈ λ‘œλ”© 쀑... (이 λ©”μ‹œμ§€λŠ” ν•œ 번만 보여야 ν•©λ‹ˆλ‹€)")
81
- model = AutoModelForSequenceClassification.from_pretrained("hun3359/klue-bert-base-sentiment")
 
 
82
  model.eval()
83
  return model
84
 
85
- @st.cache_resource
86
- def load_tokenizer():
87
- print("ν† ν¬λ‚˜μ΄μ € λ‘œλ”© 쀑... (이 λ©”μ‹œμ§€λŠ” ν•œ 번만 보여야 ν•©λ‹ˆλ‹€)")
88
- return AutoTokenizer.from_pretrained("hun3359/klue-bert-base-sentiment")
89
-
90
-
91
- @st.cache_data(show_spinner=False)
92
- def load_csv_any(p):
93
- return pd.read_csv(p) if str(p).startswith(("http://","https://")) else pd.read_csv(p)
94
-
95
- # trip_url = st.secrets.get("TRIPDATA_URL")
96
- # if not trip_url:
97
- # st.error("TRIPDATA_URL λ―Έμ„€μ •: Streamlit Secrets에 URL을 λ„£μ–΄μ£Όμ„Έμš”.")
98
- # st.stop()
99
-
100
- travel_df = load_csv_smart("trip_emotions.csv")
101
- festival_df = load_csv_smart("festivals.csv")
102
- external_score_df = load_csv_smart("external_scores.csv")
103
- external_score_df.columns = external_score_df.columns.str.strip()
104
- weather_df = load_csv_smart("weather.csv")
105
- package_df = load_csv_smart("packages.csv")
106
- package_df.columns = package_df.columns.str.strip()
107
- master_df = load_csv_smart("countries_cities.csv")
108
-
109
- countries = travel_df["μ—¬ν–‰λ‚˜λΌ"].dropna().unique().tolist()
110
- cities = travel_df["μ—¬ν–‰λ„μ‹œ"].dropna().unique().tolist()
111
 
112
  def detect_location_filter(text, intent_score=None):
113
  def in_text_exact(word):
 
1
+ # -*- coding: utf-8 -*-
2
+ import os, io, json, pathlib, re, random
 
 
 
 
3
  import pandas as pd
4
+ import streamlit as st
5
  import torch
 
 
6
  import torch.nn.functional as F
7
  from collections import defaultdict
8
  from datetime import datetime
9
+ from huggingface_hub import hf_hub_download
10
+ from sentence_transformers import SentenceTransformer, util
11
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
12
+ from css import log_and_render
13
 
14
# ──────────────────────────────── Cache / dataset settings ────────────────────────────────
# Where Hugging Face model files are cached; TRANSFORMERS_CACHE env var overrides.
# NOTE(review): default moved from /tmp/hf-cache to /data/hf-cache in this commit —
# confirm the runtime user can create and write /data, or makedirs below will raise.
CACHE_DIR = os.getenv("TRANSFORMERS_CACHE", "/data/hf-cache")
os.makedirs(CACHE_DIR, exist_ok=True)

# Hub dataset repo (and revision) that load_csv_smart() pulls CSV files from;
# both can be overridden via environment variables.
HF_DATASET_REPO = os.getenv("HF_DATASET_REPO", "emisdfde/moai-travel-data")
HF_DATASET_REV = os.getenv("HF_DATASET_REV", "main")
20
 
21
  def _is_pointer_bytes(b: bytes) -> bool:
22
  head = b[:2048].decode(errors="ignore").lower()
 
23
  return (
24
+ "version https://git-lfs.github.com/spec/v1" in head
25
+ or "git-lfs" in head
26
+ or "xet" in head
27
+ or "pointer size" in head
28
  )
29
 
30
  def _read_csv_bytes(b: bytes) -> pd.DataFrame:
 
37
  hub_filename: str | None = None,
38
  repo_id: str = HF_DATASET_REPO,
39
  repo_type: str = "dataset",
40
+ revision: str = HF_DATASET_REV) -> pd.DataFrame:
41
  if hub_filename is None:
42
  hub_filename = os.path.basename(local_path)
43
  if os.path.exists(local_path):
 
55
  except UnicodeDecodeError:
56
  return pd.read_csv(cached, encoding="cp949")
57
 
58
# ──────────────────────────────── Global data containers (lazy init) ────────────────────────────────
# Every frame starts as None and is published exactly once by init_datasets();
# helpers call _assert_ready() before touching them.
travel_df = None
festival_df = None
external_score_df = None
weather_df = None
package_df = None
master_df = None
countries = []
cities = []
theme_title_phrases = {}
62
 
63
+ def _strip_columns(df: pd.DataFrame | None) -> pd.DataFrame | None:
64
+ if df is not None and hasattr(df, "columns"):
65
+ df.columns = df.columns.str.strip()
66
+ return df
67
+
68
def init_datasets(*,
                  travel_df: pd.DataFrame,
                  festival_df: pd.DataFrame,
                  external_score_df: pd.DataFrame,
                  weather_df: pd.DataFrame,
                  package_df: pd.DataFrame,
                  master_df: pd.DataFrame,
                  theme_title_phrases: dict | None = None):
    """Populate the module-level dataset globals; call exactly once from
    app.py after the data load completes.

    Each incoming frame is copied and its column names whitespace-stripped
    before being published, so the caller's frames are never mutated.
    theme_title_phrases is stored as-is (not copied) when provided.

    Raises:
        KeyError: if travel_df is missing any of the required columns
            ("여행나라", "여행도시", "여행지").
    """
    # The keyword parameters shadow the module globals of the same name,
    # so assignment goes through globals()[...] rather than a `global` decl.
    globals()["travel_df"] = _strip_columns(travel_df.copy())
    globals()["festival_df"] = _strip_columns(festival_df.copy())
    globals()["external_score_df"] = _strip_columns(external_score_df.copy())
    globals()["weather_df"] = _strip_columns(weather_df.copy())
    globals()["package_df"] = _strip_columns(package_df.copy())
    globals()["master_df"] = _strip_columns(master_df.copy())
    if theme_title_phrases is not None:
        globals()["theme_title_phrases"] = theme_title_phrases

    # Required-column check (country / city / destination).
    req = ["여행나라", "여행도시", "여행지"]
    miss = [c for c in req if c not in globals()["travel_df"].columns]
    if miss:
        raise KeyError(f"travel_df 필수 컬럼 누락: {miss} / 실제: {list(globals()['travel_df'].columns)}")

    # Derived lookup lists, sorted for stable UI ordering.
    global countries, cities
    countries = sorted(globals()["travel_df"]["여행나라"].dropna().unique().tolist())
    cities = sorted(globals()["travel_df"]["여행도시"].dropna().unique().tolist())
96
+
97
def _assert_ready():
    """Guard for query helpers: fail loudly when init_datasets() was skipped."""
    initialized = globals()["travel_df"] is not None
    if not initialized:
        raise RuntimeError("chat_a.init_datasets(...)를 먼저 호출해주세요.")
100
+
101
+ # ──────────────────────────────── λͺ¨λΈ λ‘œλ” (μΊμ‹œ/κΆŒν•œ μ•ˆμ „) ────────────────────────────────
102
@st.cache_resource(show_spinner=False)
def load_tokenizer():
    """Return the KLUE-BERT sentiment tokenizer, cached once per process.

    Downloads go into CACHE_DIR so the Space's writable cache is used.
    """
    repo_id = "hun3359/klue-bert-base-sentiment"
    return AutoTokenizer.from_pretrained(repo_id, cache_dir=CACHE_DIR)
 
 
 
 
 
 
 
106
 
107
@st.cache_resource(show_spinner=False)
def load_sentiment_model():
    """Return the KLUE-BERT sentiment classifier in eval mode (cached once)."""
    classifier = AutoModelForSequenceClassification.from_pretrained(
        "hun3359/klue-bert-base-sentiment", cache_dir=CACHE_DIR)
    classifier.eval()  # inference only: switch off training-mode layers
    return classifier
115
 
116
@st.cache_resource(show_spinner=False)
def load_sbert_model():
    """Return the Korean SBERT encoder, cached once per process.

    cache_folder is pinned to CACHE_DIR so SentenceTransformer downloads
    land in the same writable cache as the transformers models.
    """
    return SentenceTransformer("jhgan/ko-sroberta-multitask", cache_folder=CACHE_DIR)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  def detect_location_filter(text, intent_score=None):
123
  def in_text_exact(word):