wenbemi
committed
Update chat_a.py
chat_a.py
CHANGED
@@ -1,33 +1,30 @@
-
-
-
-# In[10]:
-import os, io, pathlib
-from huggingface_hub import hf_hub_download
+# -*- coding: utf-8 -*-
+import os, io, json, pathlib, re, random
 import pandas as pd
+import streamlit as st
 import torch
-from sentence_transformers import SentenceTransformer, util
-from transformers import AutoTokenizer, AutoModelForSequenceClassification
 import torch.nn.functional as F
 from collections import defaultdict
 from datetime import datetime
-import
-import
+from huggingface_hub import hf_hub_download
+from sentence_transformers import SentenceTransformer, util
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+from css import log_and_render
 
-
+# ──────────────── Cache / dataset settings ────────────────
+CACHE_DIR = os.getenv("TRANSFORMERS_CACHE", "/data/hf-cache")
 os.makedirs(CACHE_DIR, exist_ok=True)
 
-HF_DATASET_REPO = os.getenv("HF_DATASET_REPO", "emisdfde/moai-travel-data")
+HF_DATASET_REPO = os.getenv("HF_DATASET_REPO", "emisdfde/moai-travel-data")
 HF_DATASET_REV = os.getenv("HF_DATASET_REV", "main")
 
 def _is_pointer_bytes(b: bytes) -> bool:
     head = b[:2048].decode(errors="ignore").lower()
-    # detect both git-lfs and xet pointer text patterns
     return (
-        "version https://git-lfs.github.com/spec/v1" in head
-        "git-lfs" in head
-        "xet" in head
-        "pointer size" in head
+        "version https://git-lfs.github.com/spec/v1" in head
+        or "git-lfs" in head
+        or "xet" in head
+        or "pointer size" in head
     )
 
 def _read_csv_bytes(b: bytes) -> pd.DataFrame:
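On Spaces, files tracked with Git LFS or Xet can come back as a small pointer text rather than the real CSV, which is what _is_pointer_bytes screens for before parsing. Note the removed side of the hunk lacked the `or` between the four substring checks, so the `or`-joined version is the substantive fix here. A quick sanity check of the new predicate (both payloads below are illustrative, not from the repo):

# Illustrative payloads: a typical git-lfs pointer header vs. ordinary CSV bytes.
lfs_pointer = (
    b"version https://git-lfs.github.com/spec/v1\n"
    b"oid sha256:0000000000000000000000000000000000000000000000000000000000000000\n"
    b"size 12345\n"
)
csv_bytes = "여행나라,여행도시\n일본,도쿄\n".encode("utf-8")

assert _is_pointer_bytes(lfs_pointer)      # pointer text -> True
assert not _is_pointer_bytes(csv_bytes)    # real CSV content -> False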
@@ -40,7 +37,7 @@ def load_csv_smart(local_path: str,
                    hub_filename: str | None = None,
                    repo_id: str = HF_DATASET_REPO,
                    repo_type: str = "dataset",
-                   revision: str = HF_DATASET_REV):
+                   revision: str = HF_DATASET_REV) -> pd.DataFrame:
     if hub_filename is None:
         hub_filename = os.path.basename(local_path)
     if os.path.exists(local_path):
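Only the signature and tail of load_csv_smart appear in this diff; the change just adds an explicit -> pd.DataFrame return annotation. Assuming the body does what the signature and the surrounding code suggest (use the local file if it exists, otherwise hf_hub_download from the dataset repo at the pinned revision, falling back from utf-8 to cp949), a call site would look like:

# Hypothetical call site: the keyword defaults mirror the module-level
# constants, so usually only the filename is needed.
weather_df = load_csv_smart(
    "weather.csv",            # local path; its basename doubles as the hub filename
    repo_id=HF_DATASET_REPO,  # "emisdfde/moai-travel-data" unless overridden
    repo_type="dataset",
    revision=HF_DATASET_REV,  # "main" unless overridden
)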
@@ -58,56 +55,69 @@ def load_csv_smart(local_path: str,
     except UnicodeDecodeError:
         return pd.read_csv(cached, encoding="cp949")
 
+# ──────────────── Global data containers (lazy initialization) ────────────────
+travel_df = festival_df = external_score_df = weather_df = package_df = master_df = None
+countries, cities = [], []
+theme_title_phrases = {}
 
-
-
-
-
+def _strip_columns(df: pd.DataFrame | None) -> pd.DataFrame | None:
+    if df is not None and hasattr(df, "columns"):
+        df.columns = df.columns.str.strip()
+    return df
+
+def init_datasets(*,
+                  travel_df: pd.DataFrame,
+                  festival_df: pd.DataFrame,
+                  external_score_df: pd.DataFrame,
+                  weather_df: pd.DataFrame,
+                  package_df: pd.DataFrame,
+                  master_df: pd.DataFrame,
+                  theme_title_phrases: dict | None = None):
+    """Called exactly once from app.py after the data has finished loading."""
+    globals()["travel_df"] = _strip_columns(travel_df.copy())
+    globals()["festival_df"] = _strip_columns(festival_df.copy())
+    globals()["external_score_df"] = _strip_columns(external_score_df.copy())
+    globals()["weather_df"] = _strip_columns(weather_df.copy())
+    globals()["package_df"] = _strip_columns(package_df.copy())
+    globals()["master_df"] = _strip_columns(master_df.copy())
+    if theme_title_phrases is not None:
+        globals()["theme_title_phrases"] = theme_title_phrases
+
+    # check that the required columns are present
+    req = ["여행나라", "여행도시", "여행지"]
+    miss = [c for c in req if c not in globals()["travel_df"].columns]
+    if miss:
+        raise KeyError(f"travel_df is missing required columns: {miss} / actual: {list(globals()['travel_df'].columns)}")
+
+    # derived lists
+    global countries, cities
+    countries = sorted(globals()["travel_df"]["여행나라"].dropna().unique().tolist())
+    cities = sorted(globals()["travel_df"]["여행도시"].dropna().unique().tolist())
+
+def _assert_ready():
+    if globals()["travel_df"] is None:
+        raise RuntimeError("Call chat_a.init_datasets(...) first.")
+
+# ──────────────── Model loading (cache/permission fix) ────────────────
 @st.cache_resource(show_spinner=False)
-def
-
-
-    model = AutoModelForSequenceClassification.from_pretrained(repo, cache_dir=CACHE_DIR)
-    return tok, model
-
-@st.cache_resource
-def load_sbert_model():
-    print("Loading the SBERT model... (this message should only appear once)")
-    return SentenceTransformer("jhgan/ko-sroberta-multitask")
+def load_tokenizer():
+    return AutoTokenizer.from_pretrained("hun3359/klue-bert-base-sentiment",
+                                         cache_dir=CACHE_DIR)
 
-@st.cache_resource
+@st.cache_resource(show_spinner=False)
 def load_sentiment_model():
-
-
+    model = AutoModelForSequenceClassification.from_pretrained(
+        "hun3359/klue-bert-base-sentiment",
+        cache_dir=CACHE_DIR
+    )
     model.eval()
     return model
 
-@st.cache_resource
-def
-
-    return
-
-
-@st.cache_data(show_spinner=False)
-def load_csv_any(p):
-    return pd.read_csv(p) if str(p).startswith(("http://","https://")) else pd.read_csv(p)
-
-# trip_url = st.secrets.get("TRIPDATA_URL")
-# if not trip_url:
-#     st.error("TRIPDATA_URL is not set: add the URL to Streamlit Secrets.")
-#     st.stop()
-
-travel_df = load_csv_smart("trip_emotions.csv")
-festival_df = load_csv_smart("festivals.csv")
-external_score_df = load_csv_smart("external_scores.csv")
-external_score_df.columns = external_score_df.columns.str.strip()
-weather_df = load_csv_smart("weather.csv")
-package_df = load_csv_smart("packages.csv")
-package_df.columns = package_df.columns.str.strip()
-master_df = load_csv_smart("countries_cities.csv")
-
-countries = travel_df["여행나라"].dropna().unique().tolist()
-cities = travel_df["여행도시"].dropna().unique().tolist()
+@st.cache_resource(show_spinner=False)
+def load_sbert_model():
+    # pin the cache folder on the SentenceTransformer side as well
+    return SentenceTransformer("jhgan/ko-sroberta-multitask",
+                               cache_folder=CACHE_DIR)
 
 def detect_location_filter(text, intent_score=None):
     def in_text_exact(word):
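The net effect of this hunk: the module no longer loads its CSVs at import time. Instead app.py loads them and hands the frames over once via init_datasets, which strips column whitespace, validates the required travel_df columns, and rebuilds the countries/cities lists. A sketch of the corresponding app.py wiring, reusing the same filenames the removed module-level code loaded (the actual app.py is not part of this diff):

import chat_a
from chat_a import load_csv_smart

# Load every dataset first, then initialize chat_a exactly once.
chat_a.init_datasets(
    travel_df=load_csv_smart("trip_emotions.csv"),
    festival_df=load_csv_smart("festivals.csv"),
    external_score_df=load_csv_smart("external_scores.csv"),
    weather_df=load_csv_smart("weather.csv"),
    package_df=load_csv_smart("packages.csv"),
    master_df=load_csv_smart("countries_cities.csv"),
)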
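All three loaders now carry @st.cache_resource, so each model is constructed once per Streamlit server process and shared across reruns, and every download is routed into CACHE_DIR (a path the Space can write to) instead of the default home cache. A hypothetical helper showing how the cached tokenizer and classifier compose for inference (score_sentiment is illustrative, not part of the file):

def score_sentiment(text: str) -> torch.Tensor:
    # First call downloads into CACHE_DIR; later calls reuse the cached objects.
    tok = load_tokenizer()
    model = load_sentiment_model()
    inputs = tok(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    return F.softmax(logits, dim=-1)  # class probabilities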
|