jonghhhh commited on
Commit
2baf61e
ยท
verified ยท
1 Parent(s): 812af0e

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +147 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,149 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, io, time, json, re, requests
2
+ from io import BytesIO
3
+ from PIL import Image
4
  import streamlit as st
5
+ from google import genai
6
 
7
+ st.set_page_config(page_title="๋ฉ€ํ‹ฐ๋ชจ๋‹ฌ LLM ๋ฐ๋ชจ (Gemma 3 27B)", page_icon="๐Ÿค–", layout="wide")
8
+
9
+ # ===== Sidebar: ์„ค์ • =====
10
+ st.sidebar.title("๐Ÿ”ง ์„ค์ •")
11
+ api_key = st.sidebar.text_input("GOOGLE_API_KEY", value=os.getenv("GOOGLE_API_KEY", ""), type="password")
12
+ model_name = st.sidebar.selectbox("๋ชจ๋ธ", ["gemma-3-27b-it"], index=0)
13
+ rate_delay = st.sidebar.number_input("ํ˜ธ์ถœ๊ฐ„ ๋Œ€๊ธฐ(์ดˆ)", value=1.0, step=0.5, min_value=0.0)
14
+
15
+ if not api_key:
16
+ st.warning("์‚ฌ์ด๋“œ๋ฐ”์— GOOGLE_API_KEY๋ฅผ ์ž…๋ ฅํ•˜๊ฑฐ๋‚˜ ํ™˜๊ฒฝ๋ณ€์ˆ˜๋กœ ์„ค์ •ํ•ด ์ฃผ์„ธ์š”.")
17
+ st.stop()
18
+
19
+ client = genai.Client(api_key=api_key)
20
+
21
+ # ===== ์œ ํ‹ธ =====
22
+ RETRY_LIMIT = 3
23
+
24
+ def load_image(source: str) -> Image.Image:
25
+ """URL/๋กœ์ปฌ ๋ชจ๋‘ ์ง€์› + ํˆฌ๋ช… ๋ฐฐ๊ฒฝ ๋ณด์ •."""
26
+ session = requests.Session()
27
+ session.headers.update({"User-Agent": "Mozilla/5.0"})
28
+ last_err = None
29
+ for _ in range(RETRY_LIMIT):
30
+ try:
31
+ if source.startswith(("http://", "https://")):
32
+ resp = session.get(source, timeout=10)
33
+ resp.raise_for_status()
34
+ img = Image.open(BytesIO(resp.content))
35
+ else:
36
+ img = Image.open(source)
37
+
38
+ if img.mode in ("RGBA", "LA", "P"):
39
+ bg = Image.new("RGB", img.size, (255, 255, 255))
40
+ img = img.convert("RGBA")
41
+ bg.paste(img, mask=img.split()[-1] if len(img.split()) > 3 else None)
42
+ img = bg
43
+ else:
44
+ img = img.convert("RGB")
45
+ return img
46
+ except Exception as e:
47
+ last_err = e
48
+ time.sleep(0.5)
49
+ raise RuntimeError(f"์ด๋ฏธ์ง€ ๋กœ๋“œ ์‹คํŒจ: {last_err}")
50
+
51
+ RESET_PROMPT = "์ด์ „ ๋Œ€ํ™”๋Š” ๋ฌด์‹œํ•˜๊ณ , ์•„๋ž˜ ์ง€์‹œ์—๋งŒ ์‘๋‹ตํ•˜์„ธ์š”.\n\n"
52
+
53
+ def try_parse_json(raw_text: str):
54
+ """์‘๋‹ต์—์„œ JSON ์ถ”์ถœ ์‹œ๋„."""
55
+ pattern = r'```(?:json)?\s*(\{[\s\S]*?\})\s*```|(\{[\s\S]*?\})'
56
+ for g1, g2 in re.findall(pattern, raw_text):
57
+ cand = (g1 or g2).strip()
58
+ try:
59
+ return json.loads(cand)
60
+ except json.JSONDecodeError:
61
+ pass
62
+ cleaned = re.sub(r'```json|```', '', raw_text).strip()
63
+ try:
64
+ return json.loads(cleaned)
65
+ except json.JSONDecodeError:
66
+ return None
67
+
68
+ def infer_text(article: str, prompt: str):
69
+ contents = [RESET_PROMPT + prompt.strip() + "\n\n" + article.strip()]
70
+ resp = client.models.generate_content(model=model_name, contents=contents)
71
+ time.sleep(rate_delay)
72
+ text = (resp.text or "").strip()
73
+ parsed = try_parse_json(text)
74
+ return parsed if parsed is not None else {"text": text}
75
+
76
+ def infer_image(image: Image.Image, prompt: str):
77
+ contents = [RESET_PROMPT + prompt.strip(), image]
78
+ resp = client.models.generate_content(model=model_name, contents=contents)
79
+ time.sleep(rate_delay)
80
+ text = (resp.text or "").strip()
81
+ parsed = try_parse_json(text)
82
+ return parsed if parsed is not None else {"text": text}
83
+
84
+ # ===== ๊ธฐ๋ณธ ํ”„๋กฌํ”„ํŠธ(๊ฐ„๊ฒฐ ๋ฒ„์ „) =====
85
+ TEXT_PROMPT = (
86
+ "๋‹น์‹ ์€ ๊ธฐ์‚ฌ ์ •๋ณด์› ๋ถ„์„ ์ „๋ฌธ๊ฐ€์ž…๋‹ˆ๋‹ค. ๋‹ค์Œ ๊ธฐ์‚ฌ์—์„œ ์ •๋ณด์›์„ ์ถ”์ถœํ•˜๊ณ , "
87
+ "์ •๋ณด์›๋ณ„ ๋ฌ˜์‚ฌ ํ”„๋ ˆ์ž„์„ ๊ธ์ •/์ค‘๋ฆฝ/๋ถ€์ •์œผ๋กœ ํŒ์ •ํ•ด JSON์œผ๋กœ๋งŒ ์ถœ๋ ฅํ•˜์„ธ์š”.\n"
88
+ '์˜ˆ์‹œ: {"sources": ["์ •๋ณด์›A","์ •๋ณด์›B"], "frames": ["์ค‘๋ฆฝ","๋ถ€์ •"]}'
89
+ )
90
+
91
+ IMAGE_PROMPT = (
92
+ "๋‹น์‹ ์€ ๋ณด๋„์‚ฌ์ง„ ๋ถ„์„ ์ „๋ฌธ๊ฐ€์ž…๋‹ˆ๋‹ค. ์ œ๊ณต๋œ ์‚ฌ์ง„์—์„œ Donald Trump ์กด์žฌ ์—ฌ๋ถ€๋ฅผ ํŒ๋‹จํ•˜๊ณ  "
93
+ "๊ฐ์ •(emotion: positive/negative/neutral)๊ณผ ์—ญ๋™์„ฑ(dynamism: high/medium/low)์„ ํ‰๊ฐ€ํ•ด "
94
+ 'JSON์œผ๋กœ๋งŒ ์ถœ๋ ฅํ•˜์„ธ์š”. Trump๊ฐ€ ์—†์œผ๋ฉด {"trump_present": false}๋งŒ ๋ฐ˜ํ™˜ํ•˜์„ธ์š”.'
95
+ )
96
+
97
+ # ===== UI =====
98
+ st.title("๐Ÿค– ๋ฉ€ํ‹ฐ๋ชจ๋‹ฌ LLM ๋ฐ๋ชจ (Gemma 3 27B)")
99
+ tab_text, tab_img = st.tabs(["๐Ÿ“ ํ…์ŠคํŠธ ๋ถ„์„", "๐Ÿ–ผ๏ธ ์ด๋ฏธ์ง€ ๋ถ„์„"])
100
+
101
+ with tab_text:
102
+ st.subheader("๊ธฐ์‚ฌ ํ…์ŠคํŠธ โ†’ ์ •๋ณด์› & ํ”„๋ ˆ์ž„ ํŒ์ •")
103
+ article = st.text_area(
104
+ "๊ธฐ์‚ฌ ๋ณธ๋ฌธ", height=180,
105
+ value="Donald Trump์™€ Nancy Pelosi๊ฐ€ ํšŒ์˜์žฅ์—์„œ ๊ฒฉ๋ ฌํžˆ ๋…ผ์Ÿ์„ ๋ฒŒ์˜€๋‹ค..."
106
+ )
107
+ user_prompt = st.text_area("ํ”„๋กฌํ”„ํŠธ(์˜ต์…˜)", value=TEXT_PROMPT, height=120)
108
+ if st.button("ํ…์ŠคํŠธ ๋ถ„์„ ์‹คํ–‰", type="primary", use_container_width=True):
109
+ if not article.strip():
110
+ st.error("๊ธฐ์‚ฌ ๋ณธ๋ฌธ์„ ์ž…๋ ฅํ•ด ์ฃผ์„ธ์š”.")
111
+ else:
112
+ with st.spinner("๋ถ„์„ ์ค‘..."):
113
+ try:
114
+ result = infer_text(article, user_prompt or TEXT_PROMPT)
115
+ st.success("์™„๋ฃŒ")
116
+ st.json(result)
117
+ st.download_button("๊ฒฐ๊ณผ JSON ๋‹ค์šด๋กœ๋“œ", data=json.dumps(result, ensure_ascii=False, indent=2),
118
+ file_name="text_result.json", mime="application/json")
119
+ except Exception as e:
120
+ st.error(f"์˜ค๋ฅ˜: {e}")
121
+
122
+ with tab_img:
123
+ st.subheader("๋ณด๋„์‚ฌ์ง„ โ†’ ์ธ๋ฌผ ์กด์žฌยท๊ฐ์ •ยท์—ญ๋™์„ฑ ๋ถ„์„")
124
+ col1, col2 = st.columns(2)
125
+ with col1:
126
+ img_url = st.text_input("์ด๋ฏธ์ง€ URL")
127
+ with col2:
128
+ file = st.file_uploader("์ด๋ฏธ์ง€ ์—…๋กœ๋“œ", type=["jpg","jpeg","png","webp","gif","bmp"])
129
+
130
+ img_prompt = st.text_area("ํ”„๋กฌํ”„ํŠธ(์˜ต์…˜)", value=IMAGE_PROMPT, height=120)
131
+ if st.button("์ด๋ฏธ์ง€ ๋ถ„์„ ์‹คํ–‰", type="primary", use_container_width=True):
132
+ try:
133
+ if file is not None:
134
+ image = Image.open(io.BytesIO(file.read())).convert("RGB")
135
+ elif img_url.strip():
136
+ image = load_image(img_url.strip())
137
+ else:
138
+ st.error("์ด๋ฏธ์ง€ URL์„ ์ž…๋ ฅํ•˜๊ฑฐ๋‚˜ ํŒŒ์ผ์„ ์—…๋กœ๋“œํ•ด ์ฃผ์„ธ์š”.")
139
+ st.stop()
140
+
141
+ st.image(image, caption="์ž…๋ ฅ ์ด๋ฏธ์ง€", use_container_width=True)
142
+ with st.spinner("๋ถ„์„ ์ค‘..."):
143
+ result = infer_image(image, img_prompt or IMAGE_PROMPT)
144
+ st.success("์™„๋ฃŒ")
145
+ st.json(result)
146
+ st.download_button("๊ฒฐ๊ณผ JSON ๋‹ค์šด๋กœ๋“œ", data=json.dumps(result, ensure_ascii=False, indent=2),
147
+ file_name="image_result.json", mime="application/json")
148
+ except Exception as e:
149
+ st.error(f"์˜ค๋ฅ˜: {e}")