yongyeol commited on
Commit
f010769
ยท
verified ยท
1 Parent(s): 8d2a103

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +38 -22
src/streamlit_app.py CHANGED
@@ -3,9 +3,10 @@ import json
3
  import requests
4
  import streamlit as st
5
  from datetime import datetime
6
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
7
 
8
- # โœ… ์•ˆ์ „ํ•œ ์บ์‹œ ๊ฒฝ๋กœ ์„ค์ • (์ตœ์ƒ๋‹จ ํ•„์ˆ˜)
9
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
10
  os.environ["HF_HOME"] = "/tmp/hf_cache"
11
  os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_cache"
@@ -15,23 +16,25 @@ st.set_page_config(page_title="ํ•™์‚ฌ์ผ์ • ์บ˜๋ฆฐ๋”", layout="centered")
15
  st.title("๐Ÿ“… ํ•™์‚ฌ์ผ์ • ์บ˜๋ฆฐ๋” + AI ์š”์•ฝ")
16
  st.markdown("NEIS API์—์„œ ํ•™์‚ฌ์ผ์ •์„ ๋ถˆ๋Ÿฌ์˜ค๊ณ  FullCalendar๋กœ ์‹œ๊ฐํ™”ํ•ฉ๋‹ˆ๋‹ค.")
17
 
18
- # โœ… ๋””๋ฒ„๊น… ์ถœ๋ ฅ
19
- token_present = os.environ.get("HUGGINGFACE_TOKEN") is not None
20
- st.write("๐Ÿ” ํ† ํฐ ์žˆ์Œ:", token_present)
21
- st.write("โœ… ์บ์‹œ ๊ฒฝ๋กœ:", os.environ.get("TRANSFORMERS_CACHE"))
22
-
23
- # โœ… Gemma ๋ชจ๋ธ ๋กœ๋”ฉ ํ•จ์ˆ˜
24
  @st.cache_resource
25
  def load_model():
26
  token = os.environ.get("HUGGINGFACE_TOKEN")
27
- model_id = "google/gemma-2-2b-it"
28
  cache_dir = "/tmp/hf_cache"
29
 
30
  tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token, cache_dir=cache_dir)
31
- model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=token, cache_dir=cache_dir)
32
- return pipeline("text-generation", model=model, tokenizer=tokenizer)
 
 
 
 
 
 
 
33
 
34
- llm = load_model()
35
 
36
  # โœ… ํ•™๊ต ์ •๋ณด ๊ฐ€์ ธ์˜ค๊ธฐ
37
  def get_school_info(region_code, school_name, api_key):
@@ -41,7 +44,7 @@ def get_school_info(region_code, school_name, api_key):
41
  school = data.get("schoolInfo", [{}])[1].get("row", [{}])[0]
42
  return school.get("SD_SCHUL_CODE"), school.get("ATPT_OFCDC_SC_CODE")
43
 
44
- # โœ… ํ•™์‚ฌ์ผ์ • ๊ฐ€์ ธ์˜ค๊ธฐ (์›” ๋‹จ์œ„)
45
  def get_schedule(region_code, school_code, year, month, api_key):
46
  from_ymd = f"{year}{month:02}01"
47
  to_ymd = f"{year}{month:02}31"
@@ -49,13 +52,13 @@ def get_schedule(region_code, school_code, year, month, api_key):
49
  res = requests.get(url)
50
  data = res.json()
51
  rows = data.get("SchoolSchedule", [{}])[1].get("row", [])
52
- st.write("๐Ÿ“ฆ ๋ถˆ๋Ÿฌ์˜จ ์ผ์ • raw data:", rows)
53
  return rows
54
 
55
  # โœ… ์š”์•ฝ ์ƒ์„ฑ
56
  def summarize_schedule(rows, school_name, year):
57
  if not rows:
58
  return "์ผ์ •์ด ์—†์–ด ์š”์•ฝํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."
 
59
  lines = []
60
  for row in rows:
61
  date = row["AA_YMD"]
@@ -63,11 +66,26 @@ def summarize_schedule(rows, school_name, year):
63
  event = row["EVENT_NM"]
64
  lines.append(f"{dt}: {event}")
65
  text = "\n".join(lines)
 
66
  prompt = f"{school_name}๊ฐ€ {year}๋…„๋„์— ๊ฐ€์ง€๋Š” ํ•™์‚ฌ์ผ์ •์€ ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค:\n{text}\n์ฃผ์š” ์ผ์ •์„ ์š”์•ฝํ•ด์ฃผ์„ธ์š”."
67
- st.write("๐Ÿ“ค ์š”์•ฝ์— ์ „๋‹ฌ๋œ ํ”„๋กฌํ”„ํŠธ:", prompt)
68
- result = llm([{"role": "user", "content": prompt}])
69
- st.write("๐Ÿ“ฅ ๋ชจ๋ธ ์ƒ์„ฑ ๊ฒฐ๊ณผ:", result)
70
- return result[0]["generated_text"].replace(prompt, "").strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  # โœ… ์ง€์—ญ/ํ•™๊ต/๋…„๋„/์›” ์„ ํƒ UI
73
  region_options = {
@@ -83,6 +101,7 @@ with st.form("query_form"):
83
  month = st.selectbox("์›”", options=list(range(1, 13)), index=6)
84
  submitted = st.form_submit_button("๐Ÿ“… ํ•™์‚ฌ์ผ์ • ๋ถˆ๋Ÿฌ์˜ค๊ธฐ")
85
 
 
86
  if submitted:
87
  with st.spinner("์ผ์ • ๋ถˆ๋Ÿฌ์˜ค๋Š” ์ค‘..."):
88
  api_key = os.environ.get("NEIS_API_KEY", "a69e08342c8947b4a52cd72789a5ecaf")
@@ -94,7 +113,6 @@ if submitted:
94
  if not schedule_rows:
95
  st.info("ํ•ด๋‹น ์กฐ๊ฑด์˜ ํ•™์‚ฌ์ผ์ •์ด ์—†์Šต๋‹ˆ๋‹ค.")
96
  else:
97
- # โœ… ์ผ์ • ์ถœ๋ ฅ์šฉ FullCalendar ์ƒ์„ฑ
98
  events = [
99
  {
100
  "title": row["EVENT_NM"],
@@ -103,7 +121,6 @@ if submitted:
103
  for row in schedule_rows
104
  if "AA_YMD" in row and "EVENT_NM" in row
105
  ]
106
- st.write("๐Ÿ“… FullCalendar์— ์ „๋‹ฌํ•  events:", events)
107
  event_json = json.dumps(events, ensure_ascii=False)
108
 
109
  st.components.v1.html(f"""
@@ -130,10 +147,9 @@ if submitted:
130
  </html>
131
  """, height=650)
132
 
133
- # โœ… ์š”์•ฝ ์ƒ์„ฑ ๋ฒ„ํŠผ ์ถ”๊ฐ€
134
  with st.expander("โœจ 1๋…„์น˜ ์š”์•ฝ ๋ณด๊ธฐ", expanded=False):
135
  if st.button("๐Ÿค– ์š”์•ฝ ์ƒ์„ฑํ•˜๊ธฐ"):
136
- with st.spinner("Gemma ๋ชจ๋ธ์ด ์š”์•ฝ ์ค‘..."):
137
  summary = summarize_schedule(schedule_rows, school_name, year)
138
  st.success("์š”์•ฝ ์™„๋ฃŒ!")
139
  st.markdown(f"**{school_name} {year}๋…„ {month}์›” ์ผ์ • ์š”์•ฝ:**\n\n{summary}")
 
3
  import requests
4
  import streamlit as st
5
  from datetime import datetime
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
+ import torch
8
 
9
+ # โœ… ์•ˆ์ „ํ•œ ์บ์‹œ ๊ฒฝ๋กœ ์„ค์ •
10
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
11
  os.environ["HF_HOME"] = "/tmp/hf_cache"
12
  os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_cache"
 
16
  st.title("๐Ÿ“… ํ•™์‚ฌ์ผ์ • ์บ˜๋ฆฐ๋” + AI ์š”์•ฝ")
17
  st.markdown("NEIS API์—์„œ ํ•™์‚ฌ์ผ์ •์„ ๋ถˆ๋Ÿฌ์˜ค๊ณ  FullCalendar๋กœ ์‹œ๊ฐํ™”ํ•ฉ๋‹ˆ๋‹ค.")
18
 
19
+ # โœ… ๋ชจ๋ธ ๋กœ๋”ฉ ํ•จ์ˆ˜ (skt/A.X-4.0-Light)
 
 
 
 
 
20
  @st.cache_resource
21
  def load_model():
22
  token = os.environ.get("HUGGINGFACE_TOKEN")
23
+ model_id = "skt/A.X-4.0-Light"
24
  cache_dir = "/tmp/hf_cache"
25
 
26
  tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token, cache_dir=cache_dir)
27
+ model = AutoModelForCausalLM.from_pretrained(
28
+ model_id,
29
+ use_auth_token=token,
30
+ torch_dtype=torch.bfloat16,
31
+ device_map="auto",
32
+ cache_dir=cache_dir
33
+ )
34
+ model.eval()
35
+ return tokenizer, model
36
 
37
+ tokenizer, model = load_model()
38
 
39
  # โœ… ํ•™๊ต ์ •๋ณด ๊ฐ€์ ธ์˜ค๊ธฐ
40
  def get_school_info(region_code, school_name, api_key):
 
44
  school = data.get("schoolInfo", [{}])[1].get("row", [{}])[0]
45
  return school.get("SD_SCHUL_CODE"), school.get("ATPT_OFCDC_SC_CODE")
46
 
47
+ # โœ… ํ•™์‚ฌ์ผ์ • ๊ฐ€์ ธ์˜ค๊ธฐ
48
  def get_schedule(region_code, school_code, year, month, api_key):
49
  from_ymd = f"{year}{month:02}01"
50
  to_ymd = f"{year}{month:02}31"
 
52
  res = requests.get(url)
53
  data = res.json()
54
  rows = data.get("SchoolSchedule", [{}])[1].get("row", [])
 
55
  return rows
56
 
57
  # โœ… ์š”์•ฝ ์ƒ์„ฑ
58
  def summarize_schedule(rows, school_name, year):
59
  if not rows:
60
  return "์ผ์ •์ด ์—†์–ด ์š”์•ฝํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."
61
+
62
  lines = []
63
  for row in rows:
64
  date = row["AA_YMD"]
 
66
  event = row["EVENT_NM"]
67
  lines.append(f"{dt}: {event}")
68
  text = "\n".join(lines)
69
+
70
  prompt = f"{school_name}๊ฐ€ {year}๋…„๋„์— ๊ฐ€์ง€๋Š” ํ•™์‚ฌ์ผ์ •์€ ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค:\n{text}\n์ฃผ์š” ์ผ์ •์„ ์š”์•ฝํ•ด์ฃผ์„ธ์š”."
71
+
72
+ messages = [
73
+ {"role": "system", "content": "๋‹น์‹ ์€ ํ•™์‚ฌ์ผ์ •์„ ์š”์•ฝํ•ด์ฃผ๋Š” AI์ž…๋‹ˆ๋‹ค."},
74
+ {"role": "user", "content": prompt}
75
+ ]
76
+
77
+ input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
78
+
79
+ with torch.no_grad():
80
+ output = model.generate(
81
+ input_ids,
82
+ max_new_tokens=256,
83
+ do_sample=False,
84
+ )
85
+
86
+ len_prompt = input_ids.shape[1]
87
+ response = tokenizer.decode(output[0][len_prompt:], skip_special_tokens=True).strip()
88
+ return response
89
 
90
  # โœ… ์ง€์—ญ/ํ•™๊ต/๋…„๋„/์›” ์„ ํƒ UI
91
  region_options = {
 
101
  month = st.selectbox("์›”", options=list(range(1, 13)), index=6)
102
  submitted = st.form_submit_button("๐Ÿ“… ํ•™์‚ฌ์ผ์ • ๋ถˆ๋Ÿฌ์˜ค๊ธฐ")
103
 
104
+ # โœ… ์ œ์ถœ ์ฒ˜๋ฆฌ
105
  if submitted:
106
  with st.spinner("์ผ์ • ๋ถˆ๋Ÿฌ์˜ค๋Š” ์ค‘..."):
107
  api_key = os.environ.get("NEIS_API_KEY", "a69e08342c8947b4a52cd72789a5ecaf")
 
113
  if not schedule_rows:
114
  st.info("ํ•ด๋‹น ์กฐ๊ฑด์˜ ํ•™์‚ฌ์ผ์ •์ด ์—†์Šต๋‹ˆ๋‹ค.")
115
  else:
 
116
  events = [
117
  {
118
  "title": row["EVENT_NM"],
 
121
  for row in schedule_rows
122
  if "AA_YMD" in row and "EVENT_NM" in row
123
  ]
 
124
  event_json = json.dumps(events, ensure_ascii=False)
125
 
126
  st.components.v1.html(f"""
 
147
  </html>
148
  """, height=650)
149
 
 
150
  with st.expander("โœจ 1๋…„์น˜ ์š”์•ฝ ๋ณด๊ธฐ", expanded=False):
151
  if st.button("๐Ÿค– ์š”์•ฝ ์ƒ์„ฑํ•˜๊ธฐ"):
152
+ with st.spinner("๋ชจ๋ธ์ด ์š”์•ฝ ์ค‘..."):
153
  summary = summarize_schedule(schedule_rows, school_name, year)
154
  st.success("์š”์•ฝ ์™„๋ฃŒ!")
155
  st.markdown(f"**{school_name} {year}๋…„ {month}์›” ์ผ์ • ์š”์•ฝ:**\n\n{summary}")