RFx_FirstTry / app.py
dohyune's picture
Upload app.py
e6afbcf verified
"""
RFx ๋ฌธ์„œ ๋ถ„์„ AI ์—์ด์ „ํŠธ (PDF Text Highlighting)
PDF ํ…์ŠคํŠธ์— ์ง์ ‘ ํ•˜์ด๋ผ์ดํŠธ ํ‘œ์‹œ
"""
import streamlit as st
import fitz # PyMuPDF
import chromadb
from sentence_transformers import SentenceTransformer, util
import requests
import os
import re
import shutil
from collections import Counter
import numpy as np
from typing import List, Dict, Tuple
import base64
# xAI Grok chat-completions credentials; the key is read from the environment
# and can be overridden from the sidebar at runtime.
GROK_API_KEY = os.getenv("GROK_API_KEY")
GROK_API_BASE = "https://api.x.ai/v1"
# On-disk location of the persistent ChromaDB vector store (wiped per upload).
CHROMA_DIR = "./chroma_db"
# Korean sentence-embedding model used for retrieval.
EMBEDDING_MODEL = 'jhgan/ko-sroberta-multitask'
# Page chrome: wide layout leaves room for the two-column (PDF viewer + chat)
# body; the sidebar starts collapsed until the user needs upload/settings.
st.set_page_config(
    page_title="RFx ๋ฌธ์„œ ๋ถ„์„ AI",
    page_icon="๐Ÿ“„",
    layout="wide",
    initial_sidebar_state="collapsed"
)
# Global CSS injected once: page title, source-citation boxes, keyword badges,
# the scrollable PDF container, and the page/highlight indicator chips used by
# the viewer column.
st.markdown("""
<style>
.main-title {
font-size: 1.8rem;
font-weight: bold;
color: #1E3A8A;
margin-bottom: 1rem;
text-align: center;
}
.source-box {
background: #F1F5F9;
padding: 1rem;
border-radius: 0.5rem;
margin: 0.5rem 0;
border-left: 3px solid #3B82F6;
}
.source-title {
font-weight: bold;
color: #1E40AF;
margin-bottom: 0.5rem;
}
.keyword-badge {
display: inline-block;
background: #DBEAFE;
color: #1E40AF;
padding: 0.2rem 0.6rem;
border-radius: 0.3rem;
margin: 0.2rem;
font-size: 0.85rem;
}
.pdf-container {
border: 2px solid #E2E8F0;
border-radius: 0.5rem;
padding: 0.5rem;
height: 800px;
overflow-y: auto;
background: white;
}
.page-indicator {
background: #3B82F6;
color: white;
padding: 0.3rem 0.8rem;
border-radius: 1rem;
font-size: 0.85rem;
display: inline-block;
margin: 0.2rem;
}
.highlight-indicator {
background: #FEF08A;
color: #854D0E;
padding: 0.5rem 1rem;
border-radius: 0.5rem;
margin: 0.5rem 0;
font-weight: bold;
border-left: 4px solid #EAB308;
}
</style>
""", unsafe_allow_html=True)
def init_session():
    """Seed ``st.session_state`` with every key the app relies on.

    Safe to call on each Streamlit rerun: keys that already exist are
    never overwritten, so user state survives reruns.
    """
    defaults = {
        'processed': False,           # a document has been fully indexed
        'vector_db': None,            # ChromaDB collection handle
        'embedder': None,             # SentenceTransformer instance
        'chat_history': [],           # list of {"role", "content", ...} dicts
        'doc_metadata': {},           # filename / chunk / page counts
        'pdf_bytes': None,            # raw uploaded PDF bytes
        'pdf_pages_text': {},         # 1-based page number -> page text
        'current_highlights': [],     # [{'page', 'text'}] for the viewer
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
def extract_text_from_pdf(pdf_file) -> Tuple[List[str], List[Dict], bytes, Dict]:
    """Read an uploaded PDF and split its text into overlapping chunks.

    Returns ``(chunks, metadata_list, pdf_bytes, pages_text)`` where
    ``metadata_list[i]`` describes ``chunks[i]`` (1-based page number,
    source file name, chunk type) and ``pages_text`` maps 1-based page
    number -> the page's full extracted text.

    ``pdf_file`` is expected to be a Streamlit ``UploadedFile`` (provides
    ``.read()`` and ``.name``) — confirm against the caller.
    """
    pdf_bytes = pdf_file.read()
    doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    chunks = []
    metadata_list = []
    pages_text = {}
    CHUNK_SIZE = 300   # target maximum characters per chunk
    OVERLAP_SIZE = 60  # trailing characters carried into the next chunk
    for page_num in range(len(doc)):
        page = doc[page_num]
        text = page.get_text("text")
        pages_text[page_num + 1] = text
        if not text.strip():
            # Image-only / blank page: recorded in pages_text but not chunked.
            continue
        # Normalize whitespace: strip each line and drop empty ones before
        # sentence splitting.
        lines = [line.strip() for line in text.split('\n') if line.strip()]
        cleaned_text = '\n'.join(lines)
        # Capturing split keeps the sentence delimiters as separate list
        # items, so punctuation is re-attached when pieces are concatenated
        # back into chunks below.
        sentences = re.split(r'([.!?]\s+|\n{2,})', cleaned_text)
        sentences = [s for s in sentences if s.strip()]
        current_chunk = ""
        current_length = 0
        for sentence in sentences:
            sentence_length = len(sentence)
            if current_length + sentence_length > CHUNK_SIZE and current_chunk:
                # Flush the full chunk, then start the next one with an
                # overlap tail for retrieval continuity across boundaries.
                chunks.append(current_chunk.strip())
                metadata_list.append({
                    "page": page_num + 1,
                    "source": pdf_file.name,
                    "chunk_type": "paragraph"
                })
                overlap_text = current_chunk[-OVERLAP_SIZE:] if len(current_chunk) > OVERLAP_SIZE else current_chunk
                current_chunk = overlap_text + sentence
                current_length = len(current_chunk)
            else:
                current_chunk += sentence
                current_length += sentence_length
        if current_chunk.strip():
            # Flush the trailing partial chunk for this page.
            chunks.append(current_chunk.strip())
            metadata_list.append({
                "page": page_num + 1,
                "source": pdf_file.name,
                "chunk_type": "paragraph"
            })
    doc.close()
    return chunks, metadata_list, pdf_bytes, pages_text
@st.cache_resource
def load_embedding_model():
    """Load the sentence-embedding model once per process (Streamlit cache)."""
    return SentenceTransformer(EMBEDDING_MODEL)
def create_vector_db(chunks: List[str], metadata_list: List[Dict]):
    """Embed ``chunks`` and store them in a fresh ChromaDB collection.

    Any previous on-disk store under ``CHROMA_DIR`` is removed first so the
    collection always reflects only the most recently uploaded document.

    Returns ``(collection, embedder)``.
    """
    embedder = load_embedding_model()
    # Best-effort wipe of the previous store; rmtree can fail on locked
    # files (e.g. Windows), in which case the old directory is reused.
    if os.path.exists(CHROMA_DIR):
        try:
            shutil.rmtree(CHROMA_DIR)
        except Exception:
            pass
    client = chromadb.PersistentClient(
        path=CHROMA_DIR,
        settings=chromadb.Settings(
            anonymized_telemetry=False,
            allow_reset=True,
            is_persistent=True
        )
    )
    try:
        collection = client.get_or_create_collection(
            name="rfx_docs",
            metadata={"hnsw:space": "cosine"}
        )
    except Exception:
        # Stale/corrupt collection left behind: drop it and recreate.
        try:
            client.delete_collection("rfx_docs")
        except Exception:
            pass
        collection = client.create_collection(
            name="rfx_docs",
            metadata={"hnsw:space": "cosine"}
        )
    # Nothing to index (e.g. an image-only PDF produced zero chunks):
    # return the empty collection instead of letting collection.add()
    # fail on empty id/document lists.
    if not chunks:
        return collection, embedder
    # encode() batches internally; batch_size keeps peak memory bounded,
    # replacing the previous manual 32-item batching loop.
    embeddings = embedder.encode(
        chunks,
        batch_size=32,
        show_progress_bar=False,
        convert_to_numpy=True
    )
    collection.add(
        embeddings=embeddings.tolist(),
        documents=chunks,
        metadatas=metadata_list,
        ids=[f"doc_{i}" for i in range(len(chunks))]
    )
    return collection, embedder
def extract_keywords(text: str, top_n: int = 5) -> List[str]:
    """Extract up to ``top_n`` salient keywords from ``text``.

    Digit-bearing tokens are always kept first; remaining slots are filled
    by frequency-ranked words with the domain vocabulary boosted and
    generic stopwords removed.
    """
    numeric_tokens = re.findall(r'[๊ฐ€-ํžฃ]*\d+[๊ฐ€-ํžฃ]*', text)
    plain_words = re.findall(r'[๊ฐ€-ํžฃ]{2,}', text)
    stopwords = {
        '๊ฒƒ', '๋“ฑ', '๋ฐ', '๊ทธ', '์ด', '์ €', '์ˆ˜', '๋•Œ', '์ค‘', '๋‚ด', '๋…„', '์›”', '์ผ',
        '๊ฒฝ์šฐ', '๋Œ€ํ•œ', 'ํ†ตํ•ด', '์œ„ํ•ด', '๊ด€๋ จ', '์žˆ๋Š”', 'ํ•˜๋Š”', '๋˜๋Š”', '์ด๋Ÿฐ', '์ €๋Ÿฐ',
        '์–ด๋–ค', '๋ฌด์Šจ', '์–ด๋А', '๋ˆ„๊ตฌ', '์–ธ์ œ', '์–ด๋””', '๋ฌด์—‡', '์–ด๋–ป๊ฒŒ', '์™œ',
        '์•Œ๋ ค', '์„ค๋ช…', '๋งํ•ด', '๋Œ€ํ•ด', '๊ด€ํ•˜์—ฌ', '์žˆ๋‚˜์š”', '์ธ๊ฐ€์š”', '๋ฌด์—‡์ธ๊ฐ€์š”',
        '์–ผ๋งˆ', '์ž…๋‹ˆ๊นŒ', 'ํ•ฉ๋‹ˆ๊นŒ'
    }
    important_keywords = {
        '๊ธˆ์•ก', '๊ฐ€๊ฒฉ', '๋น„์šฉ', '์˜ˆ์‚ฐ', '์„ค๊ณ„', '์‚ฌ์—…', '๊ณผ์—…', '๊ณ„์•ฝ',
        '๊ณต์‚ฌ', '์šฉ์—ญ', '์ œ์•ˆ', '์ž…์ฐฐ', '๋‚™์ฐฐ', '๊ฒฌ์ ', '๋‹จ๊ฐ€'
    }
    freq = Counter(w for w in plain_words if w not in stopwords and len(w) >= 2)
    # Give domain-critical vocabulary a fixed frequency boost so it wins
    # ties against incidental words.
    for term in list(freq):
        if term in important_keywords:
            freq[term] += 5
    picked = [tok for tok in numeric_tokens if tok]
    for term, _count in freq.most_common(top_n * 2):
        if len(picked) >= top_n:
            break
        if term not in picked:
            picked.append(term)
    return picked[:top_n]
def rewrite_query(query: str) -> Dict[str, any]:
    """Expand a raw user question into several retrieval query variations.

    Produces the original question, two/three-keyword combinations, the
    top single keywords, and synonym-substituted forms; deduplicated in
    order and capped at 7 variations.
    """
    original = query.strip()
    cleaned = re.sub(r'[?!,.~]', '', original)
    keywords = extract_keywords(cleaned, top_n=7)
    variations = [original]
    if keywords:
        # Short keyword combinations, when enough keywords exist.
        if len(keywords) >= 2:
            variations.append(' '.join(keywords[:2]))
        if len(keywords) >= 3:
            variations.append(' '.join(keywords[:3]))
        # The strongest individual keywords on their own.
        for kw in keywords[:3]:
            if kw not in variations:
                variations.append(kw)
    synonym_map = {
        '๊ธˆ์•ก': ['๊ฐ€๊ฒฉ', '๋น„์šฉ', '์˜ˆ์‚ฐ'],
        '์„ค๊ณ„': ['๋””์ž์ธ', '๊ณ„ํš'],
        '์‚ฌ์—…': ['ํ”„๋กœ์ ํŠธ', '๊ณผ์—…'],
    }
    # For the two leading keywords, add the first novel synonym
    # substitution only (one extra variation per keyword at most).
    for keyword in keywords[:2]:
        if keyword in synonym_map:
            for syn in synonym_map[keyword]:
                substituted = original.replace(keyword, syn)
                if substituted not in variations:
                    variations.append(substituted)
                    break
    # Order-preserving dedup, dropping blank entries.
    seen = set()
    deduped = []
    for candidate in variations:
        if candidate.strip() and candidate not in seen:
            seen.add(candidate)
            deduped.append(candidate)
    return {
        'original': original,
        'cleaned': cleaned,
        'keywords': keywords,
        'variations': deduped[:7]
    }
def search_with_multiple_queries(queries: List[str], collection, embedder, top_k: int = 5) -> Dict:
    """Run every query variation against the collection and merge the hits.

    Duplicate chunk ids across variations are kept once (first occurrence
    wins); merged hits are ordered by ascending distance and cut to
    ``top_k``. Returns a Chroma-style nested result dict plus
    ``queries_used`` and ``total_found``.
    """
    merged = {}
    for q in queries:
        emb = embedder.encode([q], convert_to_numpy=True)[0]
        res = collection.query(
            query_embeddings=[emb.tolist()],
            n_results=min(top_k * 5, 30),
            include=["documents", "metadatas", "distances"]
        )
        for idx, doc_id in enumerate(res['ids'][0]):
            if doc_id in merged:
                continue
            merged[doc_id] = {
                'id': doc_id,
                'document': res['documents'][0][idx],
                'metadata': res['metadatas'][0][idx],
                'distance': res['distances'][0][idx],
                'query': q
            }
    # sorted() is stable and dicts preserve insertion order, so distance
    # ties keep first-seen order, matching the previous list-sort behavior.
    ranked = sorted(merged.values(), key=lambda hit: hit['distance'])
    best = ranked[:top_k]
    return {
        'documents': [[hit['document'] for hit in best]],
        'metadatas': [[hit['metadata'] for hit in best]],
        'distances': [[hit['distance'] for hit in best]],
        'queries_used': queries,
        'total_found': len(ranked)
    }
def rerank_results(query: str, search_results: Dict, embedder, keywords: List[str]) -> Dict:
    """Re-order retrieved chunks by a blended relevance score.

    The score mixes cosine similarity to the query, normalized keyword-hit
    counts, and — for money-related questions — a bonus for chunks that
    contain amount-like patterns. Returns the same nested-list layout as
    the input, plus a ``scores`` list aligned with the new order.
    """
    docs = search_results['documents'][0]
    metas = search_results['metadatas'][0]
    distances = search_results['distances'][0]
    if not docs:
        return {
            'documents': [[]],
            'metadatas': [[]],
            'distances': [[]],
            'scores': []
        }
    # Cosine similarity between the query and every candidate chunk.
    query_vec = embedder.encode([query], convert_to_numpy=True)[0]
    doc_vecs = embedder.encode(docs, convert_to_numpy=True)
    similarities = util.cos_sim(query_vec, doc_vecs)[0].cpu().numpy()

    def _keyword_hits(text: str) -> int:
        # Count how many of the extracted keywords appear in this chunk.
        lowered = text.lower()
        return sum(1 for kw in keywords if kw.lower() in lowered)

    keyword_scores = [_keyword_hits(d) for d in docs]
    peak_kw = max(keyword_scores)
    if peak_kw > 0:
        keyword_scores = [s / peak_kw for s in keyword_scores]
    numeric_query_terms = ['๊ธˆ์•ก', '์˜ˆ์‚ฐ', '๊ฐ€๊ฒฉ', '๋น„์šฉ', '๋‹จ๊ฐ€']
    if any(term in query for term in numeric_query_terms):
        # Money-related question: reward chunks carrying amount patterns.
        money_patterns = [
            r'\d{1,3}(?:,\d{3})+์›',
            r'\d+๋งŒ์›',
            r'\d+์–ต์›',
            r'\(์ผ๊ธˆ\s*[^)]+\)'
        ]
        numeric_scores = [
            1 if any(re.search(p, d) for p in money_patterns) else 0
            for d in docs
        ]
        peak_num = max(numeric_scores)
        if peak_num > 0:
            numeric_scores = [s / peak_num for s in numeric_scores]
        else:
            numeric_scores = [0.0 for _ in numeric_scores]
        final_scores = [
            0.6 * sim + 0.25 * kw + 0.15 * num
            for sim, kw, num in zip(similarities, keyword_scores, numeric_scores)
        ]
    else:
        final_scores = [0.7 * sim + 0.3 * kw for sim, kw in zip(similarities, keyword_scores)]
    order = np.argsort(final_scores)[::-1]
    return {
        'documents': [[docs[i] for i in order]],
        'metadatas': [[metas[i] for i in order]],
        'distances': [[distances[i] for i in order]],
        'scores': [final_scores[i] for i in order]
    }
def build_context(search_results: Dict, max_length: int = 3000) -> str:
    """Concatenate ranked chunks into one LLM context string.

    Chunks are emitted in order with a numbered, page-tagged header; once
    adding the next chunk would exceed ``max_length``, it is either
    truncated into the remaining budget (when >200 chars remain) or
    dropped, and assembly stops.
    """
    parts = []
    used = 0
    docs = search_results['documents'][0]
    metas = search_results['metadatas'][0]
    for idx, (doc, meta) in enumerate(zip(docs, metas), 1):
        block = f"[๋ฌธ์„œ {idx}] (ํŽ˜์ด์ง€ {meta['page']})\n{doc}\n"
        if used + len(block) > max_length:
            budget = max_length - used
            # Only bother truncating when a meaningful amount of room is left.
            if budget > 200:
                parts.append(f"[๋ฌธ์„œ {idx}] (ํŽ˜์ด์ง€ {meta['page']})\n{doc[:budget-50]}...\n")
            break
        parts.append(block)
        used += len(block)
    return "\n".join(parts)
def generate_answer(query: str, search_results: Dict, api_key: str) -> str:
    """Ask the Grok chat-completions API to answer ``query`` grounded in
    the retrieved chunks.

    Returns the model's answer text, or a human-readable error string —
    this function never raises, so the chat UI can always render something.
    """
    # Cap the grounding context; build_context truncates past max_length.
    context = build_context(search_results, max_length=4000)
    system_prompt = """๋‹น์‹ ์€ RFx ๋ฌธ์„œ ์ „๋ฌธ ๋ถ„์„๊ฐ€์ž…๋‹ˆ๋‹ค.
**์ค‘์š” ์›์น™:**
1. ์ œ๊ณต๋œ ๋ฌธ์„œ๋ฅผ **๋งค์šฐ ๊ผผ๊ผผํžˆ** ์ฝ๊ณ  ์ •ํ™•ํ•œ ์ •๋ณด๋ฅผ ์ฐพ์œผ์„ธ์š”
2. ์ˆซ์ž, ๊ธˆ์•ก, ๋‚ ์งœ ๋“ฑ ๊ตฌ์ฒด์ ์ธ ์ •๋ณด๋ฅผ ์šฐ์„ ์ ์œผ๋กœ ์ฐพ์œผ์„ธ์š”
3. ๋ฌธ์„œ์— ์ •๋ณด๊ฐ€ ์žˆ๋Š”๋ฐ๋„ "์—†๋‹ค"๊ณ  ํ•˜์ง€ ๋งˆ์„ธ์š”
4. ๋‹ต๋ณ€ ์‹œ ๋ฐ˜๋“œ์‹œ [๋ฌธ์„œ N, ํŽ˜์ด์ง€ X] ํ˜•ํƒœ๋กœ ์ถœ์ฒ˜ ๋ช…์‹œ
5. ์• ๋งคํ•œ ํ‘œํ˜„ ๋Œ€์‹  ๊ตฌ์ฒด์ ์ธ ์ˆ˜์น˜๋ฅผ ์ œ๊ณตํ•˜์„ธ์š”
**๋‹ต๋ณ€ ํ˜•์‹:**
- ํ•ต์‹ฌ ๋‹ต๋ณ€์„ ๋จผ์ € ๋ช…ํ™•ํ•˜๊ฒŒ ์ œ์‹œ
- ์ถœ์ฒ˜ ๋ช…์‹œ (ํŽ˜์ด์ง€ ๋ฒˆํ˜ธ ํฌํ•จ)
- ํ•„์š”์‹œ ์ถ”๊ฐ€ ๊ด€๋ จ ์ •๋ณด ์ œ๊ณต"""
    user_prompt = f"""๋‹ค์Œ ๋ฌธ์„œ๋“ค์„ **๋งค์šฐ ๊ผผ๊ผผํžˆ** ์ฝ๊ณ  ์งˆ๋ฌธ์— ๋‹ต๋ณ€ํ•˜์„ธ์š”.
<๋ฌธ์„œ>
{context}
</๋ฌธ์„œ>
<์งˆ๋ฌธ>
{query}
</์งˆ๋ฌธ>
**์ค‘์š”**:
- ๋ฌธ์„œ๋ฅผ ์ฒ˜์Œ๋ถ€ํ„ฐ ๋๊นŒ์ง€ ์ฃผ์˜ ๊นŠ๊ฒŒ ์ฝ์œผ์„ธ์š”
- ์ˆซ์ž, ๊ธˆ์•ก ๋“ฑ ๊ตฌ์ฒด์ ์ธ ์ •๋ณด๋ฅผ ์ฐพ์œผ์„ธ์š”
- ์ฐพ์€ ์ •๋ณด๋Š” ์ •ํ™•ํžˆ ์ธ์šฉํ•˜์„ธ์š”
- ์ •๋ง๋กœ ๋ฌธ์„œ์— ์—†๋Š” ๊ฒฝ์šฐ์—๋งŒ "์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค"๋ผ๊ณ  ํ•˜์„ธ์š”"""
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }
    payload = {
        "model": "grok-3",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        "temperature": 0.1,  # low temperature: factual extraction, not creativity
        "max_tokens": 2000,
        "stream": False
    }
    try:
        response = requests.post(
            f"{GROK_API_BASE}/chat/completions",
            headers=headers,
            json=payload,
            timeout=30
        )
        if response.status_code != 200:
            # Surface the API's own error message when the body is JSON;
            # fall back to the raw response text otherwise.
            error_detail = ""
            try:
                error_data = response.json()
                error_detail = error_data.get('error', {}).get('message', '')
            except Exception:
                error_detail = response.text
            return f"โŒ API ์˜ค๋ฅ˜ (์ฝ”๋“œ: {response.status_code})\n\n{error_detail}"
        result = response.json()
        return result["choices"][0]["message"]["content"]
    except Exception as e:
        # Network failures / timeouts are reported inline rather than raised.
        return f"โŒ ์˜ค๋ฅ˜: {str(e)}"
def highlight_text_in_pdf(pdf_bytes: bytes, highlight_info: List[Dict]) -> bytes:
    """Return a copy of the PDF with yellow highlight annotations added.

    ``highlight_info`` items are ``{'page': 1-based page number, 'text':
    string to search for on that page}``. The original bytes are never
    modified; a new PDF is serialized and returned.
    """
    doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    for item in highlight_info:
        page_index = item['page'] - 1
        search_text = item['text']
        if page_index >= len(doc):
            continue  # stale highlight pointing past the end of the document
        page = doc[page_index]
        # Try progressively normalized variants: extracted text often differs
        # from the PDF's own layout in spacing or thousands separators.
        text_variations = [
            search_text,
            search_text.replace(' ', ''),
            search_text.replace(',', ''),
        ]
        for text_var in text_variations:
            text_instances = page.search_for(text_var)
            for inst in text_instances:
                highlight = page.add_highlight_annot(inst)
                highlight.set_colors(stroke=[1, 1, 0])
                highlight.update()
            if text_instances:
                # Fix: stop after the first variant that matched, instead of
                # also searching the remaining variants and stacking duplicate
                # overlapping annotations on the same spot.
                break
    output_bytes = doc.tobytes()
    doc.close()
    return output_bytes
def extract_highlight_texts(documents: List[str], keywords: List[str]) -> List[str]:
    """Collect up to 10 short strings worth highlighting in the PDF.

    Gathers money-amount matches, date matches, and short (<100 chars)
    sentences containing any keyword, then deduplicates and returns the
    shortest candidates first.
    """
    money_patterns = [
        r'\d{1,3}(?:,\d{3})+์›',
        r'\d+๋งŒ์›',
        r'\d+์–ต์›',
        r'\(์ผ๊ธˆ\s*[^)]+\)',
    ]
    date_patterns = [
        r'\d{4}[๋…„.]\d{1,2}[์›”.]\d{1,2}์ผ?',
        r'\d{2}\.\d{2}\.\d{2}',
    ]
    found: List[str] = []
    for doc in documents:
        # Amount- and date-like spans are always highlight candidates.
        for pattern in money_patterns:
            found.extend(re.findall(pattern, doc))
        for pattern in date_patterns:
            found.extend(re.findall(pattern, doc))
        # Short sentences that mention an extracted keyword.
        for keyword in keywords:
            if keyword not in doc:
                continue
            for sent in re.split(r'[.!?]\s+', doc):
                if keyword in sent and len(sent) < 100:
                    found.append(sent.strip())
    deduped = list(set(found))
    deduped.sort(key=len)  # shortest snippets first: easier to locate in the PDF
    return deduped[:10]
def render_pdf_with_highlights(pdf_bytes: bytes, highlight_info: List[Dict]):
    """Render every page of the highlight-annotated PDF as inline HTML.

    Each page is rasterized at 2x zoom to PNG and embedded as a base64
    ``<img>``; pages carrying highlights get a banner listing up to three
    matched snippets. Returns the assembled HTML string.
    """
    highlighted_pdf = highlight_text_in_pdf(pdf_bytes, highlight_info)
    doc = fitz.open(stream=highlighted_pdf, filetype="pdf")
    highlighted_pages = set(h['page'] for h in highlight_info)
    pdf_html = '<div class="pdf-container">'
    for page_num in range(len(doc)):
        page = doc[page_num]
        # 2x matrix doubles the raster resolution so text stays readable.
        pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))
        img_data = pix.tobytes("png")
        img_base64 = base64.b64encode(img_data).decode()
        pdf_html += '<div style="margin-bottom: 2rem; position: relative;">'
        pdf_html += f'<div style="background: #3B82F6; color: white; padding: 0.5rem; margin-bottom: 0.5rem; border-radius: 0.3rem; font-weight: bold;">๐Ÿ“„ ํŽ˜์ด์ง€ {page_num + 1}</div>'
        if (page_num + 1) in highlighted_pages:
            # Banner summarizing what was highlighted on this page.
            page_highlights = [h for h in highlight_info if h['page'] == page_num + 1]
            highlight_texts = ', '.join([f'"{h["text"][:30]}..."' for h in page_highlights[:3]])
            pdf_html += f'<div class="highlight-indicator">โญ ํ•˜์ด๋ผ์ดํŠธ: {highlight_texts}</div>'
        pdf_html += f'<img src="data:image/png;base64,{img_base64}" style="width: 100%; border: 1px solid #E2E8F0; border-radius: 0.3rem; box-shadow: 0 1px 3px rgba(0,0,0,0.1);" />'
        pdf_html += '</div>'
    pdf_html += '</div>'
    doc.close()
    return pdf_html
def main():
    """Streamlit entry point.

    Sidebar: API-key entry, manual ChromaDB reset, PDF upload/processing.
    Body: landing hints until a document is processed; afterwards a
    two-column layout with the highlighted PDF viewer on the left and the
    retrieval-augmented chat on the right.
    """
    init_session()
    st.markdown('<div class="main-title">๐Ÿ“„ RFx ๋ฌธ์„œ ๋ถ„์„ AI ์—์ด์ „ํŠธ</div>', unsafe_allow_html=True)
    with st.sidebar:
        st.header("โš™๏ธ ์„ค์ •")
        grok_key = st.text_input("Grok API Key", value=GROK_API_KEY or "", type="password")
        if grok_key:
            # Persist the key for this session and for any library reading env.
            os.environ["GROK_API_KEY"] = grok_key
            st.session_state.grok_key = grok_key
        st.divider()
        if st.button("๐Ÿ”„ ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ์ดˆ๊ธฐํ™”", help="ChromaDB ์˜ค๋ฅ˜ ๋ฐœ์ƒ ์‹œ ํด๋ฆญ"):
            # Manual escape hatch for a corrupted/locked ChromaDB directory.
            if os.path.exists(CHROMA_DIR):
                try:
                    shutil.rmtree(CHROMA_DIR)
                    st.success("โœ… ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ์ดˆ๊ธฐํ™” ์™„๋ฃŒ!")
                    st.session_state.processed = False
                    st.session_state.vector_db = None
                    st.rerun()
                except Exception as e:
                    st.error(f"์ดˆ๊ธฐํ™” ์‹คํŒจ: {str(e)}")
        st.divider()
        st.subheader("๐Ÿ“ค ๋ฌธ์„œ ์—…๋กœ๋“œ")
        uploaded_file = st.file_uploader("PDF ํŒŒ์ผ ์„ ํƒ", type=['pdf'])
        if uploaded_file:
            if st.button("๐Ÿ“„ ๋ฌธ์„œ ์ฒ˜๋ฆฌ", type="primary", disabled=st.session_state.get('processing', False)):
                if not grok_key:
                    st.error("โš ๏ธ Grok API ํ‚ค๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”!")
                    return
                # 'processing' disables the button/chat while work is in flight.
                st.session_state.processing = True
                with st.spinner("๐Ÿ“„ ๋ฌธ์„œ ์ฒ˜๋ฆฌ ์ค‘..."):
                    try:
                        chunks, metadata_list, pdf_bytes, pages_text = extract_text_from_pdf(uploaded_file)
                        st.info(f"๐Ÿ“‘ {len(chunks)}๊ฐœ ์ฒญํฌ ์ถ”์ถœ ์™„๋ฃŒ")
                        with st.expander("๐Ÿ“ ์ถ”์ถœ๋œ ํ…์ŠคํŠธ ์ƒ˜ํ”Œ", expanded=False):
                            if chunks:
                                st.text(f"์ฒซ ๋ฒˆ์งธ ์ฒญํฌ (์ด {len(chunks[0])}์ž):")
                                st.code(chunks[0][:500] + "..." if len(chunks[0]) > 500 else chunks[0])
                        with st.spinner("๐Ÿ”ง ๋ฒกํ„ฐ ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ์ƒ์„ฑ ์ค‘..."):
                            collection, embedder = create_vector_db(chunks, metadata_list)
                        # Cache everything the chat/viewer needs across reruns.
                        st.session_state.vector_db = collection
                        st.session_state.embedder = embedder
                        st.session_state.pdf_bytes = pdf_bytes
                        st.session_state.pdf_pages_text = pages_text
                        st.session_state.processed = True
                        st.session_state.doc_metadata = {
                            "filename": uploaded_file.name,
                            "chunks": len(chunks),
                            "pages": len(set(m['page'] for m in metadata_list))
                        }
                        st.success("โœ… ๋ฌธ์„œ ์ฒ˜๋ฆฌ ์™„๋ฃŒ!")
                    except Exception as e:
                        st.error(f"์˜ค๋ฅ˜: {str(e)}")
                    finally:
                        st.session_state.processing = False
        st.divider()
        if st.session_state.processed:
            # Summary card for the currently indexed document.
            st.subheader("๐Ÿ“Š ๋ฌธ์„œ ์ •๋ณด")
            meta = st.session_state.doc_metadata
            st.write(f"**ํŒŒ์ผ๋ช…:** {meta['filename']}")
            st.write(f"**ํŽ˜์ด์ง€:** {meta['pages']}ํŽ˜์ด์ง€")
            st.write(f"**์ฒญํฌ:** {meta['chunks']}๊ฐœ")
            if st.button("๐Ÿ—‘๏ธ ์ฑ„ํŒ… ์ดˆ๊ธฐํ™”"):
                st.session_state.chat_history = []
                st.session_state.current_highlights = []
                st.rerun()
    if not st.session_state.processed:
        # Landing screen: feature overview until a document is processed.
        st.info("๐Ÿ‘ˆ ์™ผ์ชฝ ์‚ฌ์ด๋“œ๋ฐ”์—์„œ PDF ๋ฌธ์„œ๋ฅผ ์—…๋กœ๋“œํ•˜์„ธ์š”")
        col1, col2, col3 = st.columns(3)
        with col1:
            st.markdown("### ๐Ÿ“„ PDF ๋ทฐ์–ด\n์›๋ณธ ๋ฌธ์„œ ํ™•์ธ")
        with col2:
            st.markdown("### ๐ŸŽจ ํ•˜์ด๋ผ์ดํŠธ\nํ•ต์‹ฌ ๋‚ด์šฉ ๊ฐ•์กฐ")
        with col3:
            st.markdown("### ๐Ÿ’ฌ AI ์ฑ—๋ด‡\n์ •ํ™•ํ•œ ๋‹ต๋ณ€")
    else:
        col1, col2 = st.columns([1, 1])
        with col1:
            st.markdown("### ๐Ÿ“„ ๋ฌธ์„œ ๋ทฐ์–ด")
            if st.session_state.pdf_bytes:
                # Re-rendered on every rerun so fresh highlights show up.
                pdf_html = render_pdf_with_highlights(
                    st.session_state.pdf_bytes,
                    st.session_state.current_highlights
                )
                st.markdown(pdf_html, unsafe_allow_html=True)
        with col2:
            st.markdown("### ๐Ÿ’ฌ AI ์ฑ—๋ด‡")
            chat_container = st.container()
            with chat_container:
                # Replay the stored conversation, including source citations.
                for msg in st.session_state.chat_history:
                    with st.chat_message(msg["role"]):
                        st.markdown(msg["content"])
                        if msg["role"] == "assistant" and "sources" in msg:
                            with st.expander("๐Ÿ“š ์ฐธ์กฐ ๋ฌธ์„œ"):
                                for i, (doc, meta) in enumerate(zip(
                                    msg["sources"]["docs"],
                                    msg["sources"]["metas"]
                                ), 1):
                                    score = msg["sources"]["scores"][i-1] if "scores" in msg["sources"] else None
                                    score_text = f" (๊ด€๋ จ๋„: {score:.2%})" if score else ""
                                    st.markdown(f"""
<div class="source-box">
<div class="source-title">
<span class="page-indicator">ํŽ˜์ด์ง€ {meta['page']}</span>
{score_text}
</div>
<div style="font-size: 0.9rem; color: #475569;">
{doc[:300]}{'...' if len(doc) > 300 else ''}
</div>
</div>
""", unsafe_allow_html=True)
            if prompt := st.chat_input("์งˆ๋ฌธ์„ ์ž…๋ ฅํ•˜์„ธ์š”...", disabled=st.session_state.get('processing', False)):
                if not st.session_state.get('grok_key'):
                    st.error("โš ๏ธ Grok API ํ‚ค๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”!")
                    return
                with st.chat_message("user"):
                    st.markdown(prompt)
                st.session_state.chat_history.append({"role": "user", "content": prompt})
                with st.chat_message("assistant"):
                    with st.spinner("๐Ÿ” ๊ฒ€์ƒ‰ ๋ฐ ๋ถ„์„ ์ค‘..."):
                        try:
                            # 1) Expand the question into search variations.
                            query_info = rewrite_query(prompt)
                            with st.expander("๐Ÿ” ๊ฒ€์ƒ‰ ๋””๋ฒ„๊ทธ ์ •๋ณด", expanded=False):
                                st.write("**์ถ”์ถœ๋œ ํ‚ค์›Œ๋“œ:**", query_info['keywords'])
                                st.write("**๊ฒ€์ƒ‰ ์ฟผ๋ฆฌ ๋ณ€ํ˜•:**", query_info['variations'])
                            # 2) Multi-query retrieval against the vector DB.
                            search_results = search_with_multiple_queries(
                                query_info['variations'],
                                st.session_state.vector_db,
                                st.session_state.embedder,
                                top_k=7
                            )
                            with st.expander("๐Ÿ“„ ๊ฒ€์ƒ‰๋œ ๋ฌธ์„œ ๋‚ด์šฉ", expanded=False):
                                st.write(f"**์ด {search_results.get('total_found', 0)}๊ฐœ ๋ฌธ์„œ ๋ฐœ๊ฒฌ**")
                                for i, doc in enumerate(search_results['documents'][0][:3], 1):
                                    st.write(f"**๋ฌธ์„œ {i}:**")
                                    st.text(doc[:300] + "..." if len(doc) > 300 else doc)
                                    st.divider()
                            if 'total_found' in search_results:
                                st.success(f"โœ… {search_results['total_found']}๊ฐœ ๋ฌธ์„œ์—์„œ ์ƒ์œ„ 7๊ฐœ ์„ ํƒ")
                            # 3) Re-rank hits, then 4) generate the grounded answer.
                            reranked_results = rerank_results(
                                query_info['original'],
                                search_results,
                                st.session_state.embedder,
                                query_info['keywords']
                            )
                            answer = generate_answer(
                                query_info['original'],
                                reranked_results,
                                st.session_state.grok_key
                            )
                            st.markdown(answer)
                            # 5) Derive highlight snippets for the PDF viewer.
                            highlight_texts = extract_highlight_texts(
                                reranked_results['documents'][0],
                                query_info['keywords']
                            )
                            highlights = []
                            for doc, meta in zip(reranked_results['documents'][0],
                                                 reranked_results['metadatas'][0]):
                                for text in highlight_texts:
                                    if text in doc:
                                        highlights.append({
                                            'page': meta['page'],
                                            'text': text
                                        })
                            st.session_state.current_highlights = highlights
                            st.session_state.chat_history.append({
                                "role": "assistant",
                                "content": answer,
                                "sources": {
                                    "docs": reranked_results['documents'][0],
                                    "metas": reranked_results['metadatas'][0],
                                    "scores": reranked_results.get('scores', []),
                                    "keywords": query_info['keywords']
                                }
                            })
                            with st.expander("๐Ÿ“š ์ฐธ์กฐ ๋ฌธ์„œ", expanded=True):
                                for i, (doc, meta) in enumerate(zip(
                                    reranked_results['documents'][0],
                                    reranked_results['metadatas'][0]
                                ), 1):
                                    score = reranked_results.get('scores', [None])[i-1]
                                    score_text = f" (๊ด€๋ จ๋„: {score:.2%})" if score else ""
                                    st.markdown(f"""
<div class="source-box">
<div class="source-title">
<span class="page-indicator">ํŽ˜์ด์ง€ {meta['page']}</span>
{score_text}
</div>
<div style="font-size: 0.9rem; color: #475569;">
{doc[:300]}{'...' if len(doc) > 300 else ''}
</div>
</div>
""", unsafe_allow_html=True)
                            # Rerun so the left pane redraws with the new highlights.
                            st.rerun()
                        except Exception as e:
                            st.error(f"์˜ค๋ฅ˜: {str(e)}")
                            import traceback
                            st.code(traceback.format_exc())
# Script entry point (Streamlit executes the module top-to-bottom).
if __name__ == "__main__":
    main()