# NOTE(review): removed hosting-log residue ("Spaces:" / "Build error" lines)
# that was accidentally captured with the source; it is not part of the program.
"""
RFx document analysis AI agent (PDF text highlighting).

Streamlit app that extracts text from an uploaded RFx/RFP PDF, indexes it in a
Chroma vector store with Korean sentence embeddings, answers questions via the
Grok API, and highlights the supporting passages directly on rendered PDF pages.
"""
import base64
import functools
import os
import re
import shutil
from collections import Counter
from typing import Any, Dict, List, Tuple

import chromadb
import fitz  # PyMuPDF
import numpy as np
import requests
import streamlit as st
from sentence_transformers import SentenceTransformer, util
# --- Configuration -----------------------------------------------------------
# Grok (x.ai) API access; the key may also be entered in the sidebar at runtime.
GROK_API_KEY = os.getenv("GROK_API_KEY")
GROK_API_BASE = "https://api.x.ai/v1"
# On-disk location of the Chroma vector store (wiped on every re-index).
CHROMA_DIR = "./chroma_db"
# Korean sentence-embedding model used for both indexing and querying.
EMBEDDING_MODEL = 'jhgan/ko-sroberta-multitask'

# Streamlit page setup (must run before any other st.* rendering call).
st.set_page_config(
    page_title="RFx ๋ฌธ์ ๋ถ์ AI",
    page_icon="๐",
    layout="wide",
    initial_sidebar_state="collapsed"
)

# Global CSS for the custom HTML fragments rendered via
# st.markdown(..., unsafe_allow_html=True) elsewhere in the app.
st.markdown("""
<style>
    .main-title {
        font-size: 1.8rem;
        font-weight: bold;
        color: #1E3A8A;
        margin-bottom: 1rem;
        text-align: center;
    }
    .source-box {
        background: #F1F5F9;
        padding: 1rem;
        border-radius: 0.5rem;
        margin: 0.5rem 0;
        border-left: 3px solid #3B82F6;
    }
    .source-title {
        font-weight: bold;
        color: #1E40AF;
        margin-bottom: 0.5rem;
    }
    .keyword-badge {
        display: inline-block;
        background: #DBEAFE;
        color: #1E40AF;
        padding: 0.2rem 0.6rem;
        border-radius: 0.3rem;
        margin: 0.2rem;
        font-size: 0.85rem;
    }
    .pdf-container {
        border: 2px solid #E2E8F0;
        border-radius: 0.5rem;
        padding: 0.5rem;
        height: 800px;
        overflow-y: auto;
        background: white;
    }
    .page-indicator {
        background: #3B82F6;
        color: white;
        padding: 0.3rem 0.8rem;
        border-radius: 1rem;
        font-size: 0.85rem;
        display: inline-block;
        margin: 0.2rem;
    }
    .highlight-indicator {
        background: #FEF08A;
        color: #854D0E;
        padding: 0.5rem 1rem;
        border-radius: 0.5rem;
        margin: 0.5rem 0;
        font-weight: bold;
        border-left: 4px solid #EAB308;
    }
</style>
""", unsafe_allow_html=True)
def init_session():
    """Ensure every session-state key used by the app exists with a default."""
    defaults = {
        'processed': False,          # a document has been indexed
        'vector_db': None,           # Chroma collection handle
        'embedder': None,            # SentenceTransformer instance
        'chat_history': [],          # list of {"role", "content", ...} dicts
        'doc_metadata': {},          # filename / page / chunk counts
        'pdf_bytes': None,           # raw uploaded PDF content
        'pdf_pages_text': {},        # 1-based page number -> raw page text
        'current_highlights': [],    # [{'page': int, 'text': str}, ...]
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
def extract_text_from_pdf(pdf_file) -> Tuple[List[str], List[Dict], bytes, Dict]:
    """Extract page text from an uploaded PDF and split it into overlapping chunks.

    Args:
        pdf_file: uploaded file object (Streamlit UploadedFile-like; must
            support ``.read()`` and expose ``.name``).

    Returns:
        Tuple ``(chunks, metadata_list, pdf_bytes, pages_text)`` where
        ``chunks`` are ~300-char text chunks, ``metadata_list`` holds one dict
        per chunk (1-based page, source filename, chunk type), ``pdf_bytes``
        is the raw file content, and ``pages_text`` maps 1-based page numbers
        to that page's raw text.
    """
    pdf_bytes = pdf_file.read()
    doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    chunks = []
    metadata_list = []
    pages_text = {}
    CHUNK_SIZE = 300   # target max characters per chunk
    OVERLAP_SIZE = 60  # tail of the previous chunk carried into the next
    for page_num in range(len(doc)):
        page = doc[page_num]
        text = page.get_text("text")
        pages_text[page_num + 1] = text
        if not text.strip():
            continue  # page has no extractable text (e.g. scanned image)
        lines = [line.strip() for line in text.split('\n') if line.strip()]
        cleaned_text = '\n'.join(lines)
        # NOTE: the capturing group makes re.split keep the delimiters
        # (sentence-ending punctuation + space, or blank-line runs) as
        # separate list items, so they are appended to chunks as their own
        # "sentences".
        sentences = re.split(r'([.!?]\s+|\n{2,})', cleaned_text)
        sentences = [s for s in sentences if s.strip()]
        current_chunk = ""
        current_length = 0
        for sentence in sentences:
            sentence_length = len(sentence)
            if current_length + sentence_length > CHUNK_SIZE and current_chunk:
                # Flush the full chunk, then seed the next chunk with an
                # overlapping tail so retrieval doesn't lose boundary context.
                chunks.append(current_chunk.strip())
                metadata_list.append({
                    "page": page_num + 1,
                    "source": pdf_file.name,
                    "chunk_type": "paragraph"
                })
                overlap_text = current_chunk[-OVERLAP_SIZE:] if len(current_chunk) > OVERLAP_SIZE else current_chunk
                current_chunk = overlap_text + sentence
                current_length = len(current_chunk)
            else:
                current_chunk += sentence
                current_length += sentence_length
        # Flush the trailing partial chunk for this page.
        if current_chunk.strip():
            chunks.append(current_chunk.strip())
            metadata_list.append({
                "page": page_num + 1,
                "source": pdf_file.name,
                "chunk_type": "paragraph"
            })
    doc.close()
    return chunks, metadata_list, pdf_bytes, pages_text
@functools.lru_cache(maxsize=1)
def load_embedding_model():
    """Load the Korean sentence-embedding model, caching it for the process.

    Fix: the original reloaded the SentenceTransformer on every call (it is
    invoked from create_vector_db on each document upload), which is slow and
    memory-hungry. lru_cache keeps a single shared instance per process.

    Returns:
        The SentenceTransformer instance for EMBEDDING_MODEL.
    """
    return SentenceTransformer(EMBEDDING_MODEL)
def create_vector_db(chunks: List[str], metadata_list: List[Dict]):
    """Build a fresh persistent Chroma collection from the document chunks.

    Any previous on-disk index under CHROMA_DIR is wiped (best effort), then
    all chunks are embedded in batches and inserted with cosine distance.

    Returns:
        Tuple of (collection, embedder).
    """
    embedder = load_embedding_model()

    # Best-effort removal of any stale on-disk index.
    if os.path.exists(CHROMA_DIR):
        try:
            shutil.rmtree(CHROMA_DIR)
        except Exception:
            pass

    settings = chromadb.Settings(
        anonymized_telemetry=False,
        allow_reset=True,
        is_persistent=True,
    )
    client = chromadb.PersistentClient(path=CHROMA_DIR, settings=settings)

    # Prefer get_or_create; if that fails, drop any broken collection and retry.
    try:
        collection = client.get_or_create_collection(
            name="rfx_docs",
            metadata={"hnsw:space": "cosine"},
        )
    except Exception:
        try:
            client.delete_collection("rfx_docs")
        except Exception:
            pass
        collection = client.create_collection(
            name="rfx_docs",
            metadata={"hnsw:space": "cosine"},
        )

    # Encode in small batches to bound peak memory.
    all_embeddings = []
    batch_size = 32
    for start in range(0, len(chunks), batch_size):
        batch_vectors = embedder.encode(
            chunks[start:start + batch_size],
            show_progress_bar=False,
            convert_to_numpy=True,
        )
        all_embeddings.extend(batch_vectors)

    collection.add(
        embeddings=[vec.tolist() for vec in all_embeddings],
        documents=chunks,
        metadatas=metadata_list,
        ids=[f"doc_{idx}" for idx in range(len(chunks))],
    )
    return collection, embedder
def extract_keywords(text: str, top_n: int = 5) -> List[str]:
    """Pick up to ``top_n`` salient keywords from Korean query/document text.

    Digit-bearing tokens (amounts, counts) are always kept first; remaining
    slots are filled by frequency-ranked Korean words, with domain-critical
    terms boosted so they outrank generic frequent words.
    """
    # Tokens containing digits (e.g. amounts) are kept unconditionally.
    numeric_tokens = [t for t in re.findall(r'[๊ฐ-ํฃ]*\d+[๊ฐ-ํฃ]*', text) if t]
    tokens = re.findall(r'[๊ฐ-ํฃ]{2,}', text)
    stopwords = {
        '๊ฒ', '๋ฑ', '๋ฐ', '๊ทธ', '์ด', '์ ', '์', '๋', '์ค', '๋ด', '๋ ', '์', '์ผ',
        '๊ฒฝ์ฐ', '๋ํ', 'ํตํด', '์ํด', '๊ด๋ จ', '์๋', 'ํ๋', '๋๋', '์ด๋ฐ', '์ ๋ฐ',
        '์ด๋ค', '๋ฌด์จ', '์ด๋', '๋๊ตฌ', '์ธ์ ', '์ด๋', '๋ฌด์', '์ด๋ป๊ฒ', '์',
        '์๋ ค', '์ค๋ช ', '๋งํด', '๋ํด', '๊ดํ์ฌ', '์๋์', '์ธ๊ฐ์', '๋ฌด์์ธ๊ฐ์',
        '์ผ๋ง', '์ ๋๊น', 'ํฉ๋๊น'
    }
    important_keywords = {
        '๊ธ์ก', '๊ฐ๊ฒฉ', '๋น์ฉ', '์์ฐ', '์ค๊ณ', '์ฌ์ ', '๊ณผ์ ', '๊ณ์ฝ',
        '๊ณต์ฌ', '์ฉ์ญ', '์ ์', '์ ์ฐฐ', '๋์ฐฐ', '๊ฒฌ์ ', '๋จ๊ฐ'
    }
    freq = Counter(t for t in tokens if t not in stopwords and len(t) >= 2)
    # Boost domain-critical terms above merely-frequent ones.
    for token in freq:
        if token in important_keywords:
            freq[token] += 5
    result = list(numeric_tokens)
    for token, _ in freq.most_common(top_n * 2):
        if token not in result:
            result.append(token)
        if len(result) >= top_n:
            break
    return result[:top_n]
def rewrite_query(query: str) -> Dict[str, Any]:
    """Expand a user question into several search-query variations.

    Fix: the return annotation was ``Dict[str, any]`` — lowercase ``any`` is
    the builtin function, not a type; corrected to ``typing.Any``.

    Returns:
        Dict with 'original' (stripped query), 'cleaned' (punctuation
        removed), 'keywords' (top extracted terms), and 'variations'
        (up to 7 deduplicated query strings, original first).
    """
    original = query.strip()
    cleaned = re.sub(r'[?!,.~]', '', original)
    keywords = extract_keywords(cleaned, top_n=7)

    variations = [original]
    if keywords:
        # Keyword-only variants: top-2 and top-3 joined, then singles.
        if len(keywords) >= 2:
            variations.append(' '.join(keywords[:2]))
        if len(keywords) >= 3:
            variations.append(' '.join(keywords[:3]))
        for kw in keywords[:3]:
            if kw not in variations:
                variations.append(kw)

    # Substitute synonyms for the first matching keyword only (break below).
    synonym_map = {
        '๊ธ์ก': ['๊ฐ๊ฒฉ', '๋น์ฉ', '์์ฐ'],
        '์ค๊ณ': ['๋์์ธ', '๊ณํ'],
        '์ฌ์ ': ['ํ๋ก์ ํธ', '๊ณผ์ '],
    }
    for keyword in keywords[:2]:
        if keyword in synonym_map:
            for syn in synonym_map[keyword]:
                combined = original.replace(keyword, syn)
                if combined not in variations:
                    variations.append(combined)
            break  # only expand the first keyword that has synonyms

    # Order-preserving dedup, dropping blank entries.
    seen = set()
    unique_variations = []
    for v in variations:
        if v not in seen and v.strip():
            seen.add(v)
            unique_variations.append(v)

    return {
        'original': original,
        'cleaned': cleaned,
        'keywords': keywords,
        'variations': unique_variations[:7]
    }
def search_with_multiple_queries(queries: List[str], collection, embedder, top_k: int = 5) -> Dict:
    """Run each query variation against the collection and merge unique hits.

    Results are deduplicated by document id across all queries, sorted by
    ascending distance, and trimmed to ``top_k``. The return value mimics the
    Chroma query shape (nested single-element lists) plus bookkeeping fields.
    """
    merged = []
    seen = set()
    for q in queries:
        vec = embedder.encode([q], convert_to_numpy=True)[0]
        res = collection.query(
            query_embeddings=[vec.tolist()],
            n_results=min(top_k * 5, 30),  # over-fetch so dedup still fills top_k
            include=["documents", "metadatas", "distances"],
        )
        for idx, hit_id in enumerate(res['ids'][0]):
            if hit_id in seen:
                continue
            seen.add(hit_id)
            merged.append({
                'id': hit_id,
                'document': res['documents'][0][idx],
                'metadata': res['metadatas'][0][idx],
                'distance': res['distances'][0][idx],
                'query': q,
            })
    merged.sort(key=lambda hit: hit['distance'])
    best = merged[:top_k]
    return {
        'documents': [[hit['document'] for hit in best]],
        'metadatas': [[hit['metadata'] for hit in best]],
        'distances': [[hit['distance'] for hit in best]],
        'queries_used': queries,
        'total_found': len(merged),
    }
def rerank_results(query: str, search_results: Dict, embedder, keywords: List[str]) -> Dict:
    """Re-rank retrieved chunks by blending semantic, keyword, and money signals.

    Args:
        query: the user's original question.
        search_results: Chroma-shaped dict from search_with_multiple_queries.
        embedder: sentence-transformer used to embed query and documents.
        keywords: extracted query keywords for lexical scoring.

    Returns:
        Same nested-list shape as the input plus a flat 'scores' list, all
        sorted by descending blended score.
    """
    docs = search_results['documents'][0]
    metas = search_results['metadatas'][0]
    distances = search_results['distances'][0]
    if not docs:
        # Nothing retrieved: preserve the nested-list shape callers expect.
        return {
            'documents': [[]],
            'metadatas': [[]],
            'distances': [[]],
            'scores': []
        }
    # Fresh cosine similarity between the query and each candidate chunk.
    query_embedding = embedder.encode([query], convert_to_numpy=True)[0]
    doc_embeddings = embedder.encode(docs, convert_to_numpy=True)
    similarities = util.cos_sim(query_embedding, doc_embeddings)[0].cpu().numpy()
    # Lexical signal: count of query keywords present in each doc,
    # normalized by the best doc's count.
    keyword_scores = []
    for doc in docs:
        doc_lower = doc.lower()
        score = sum(1 for kw in keywords if kw.lower() in doc_lower)
        keyword_scores.append(score)
    if max(keyword_scores) > 0:
        keyword_scores = [s / max(keyword_scores) for s in keyword_scores]
    # Money/budget-style queries get an extra boost for docs containing amounts.
    numeric_query_terms = ['๊ธ์ก', '์์ฐ', '๊ฐ๊ฒฉ', '๋น์ฉ', '๋จ๊ฐ']
    is_numeric_query = any(term in query for term in numeric_query_terms)
    if is_numeric_query:
        money_patterns = [
            r'\d{1,3}(?:,\d{3})+์',
            r'\d+๋ง์',
            r'\d+์ต์',
            r'\(์ผ๊ธ\s*[^)]+\)'
        ]
        numeric_scores = []
        for doc in docs:
            score = 0
            for pattern in money_patterns:
                if re.search(pattern, doc):
                    score = 1  # binary hit: doc mentions at least one amount
                    break
            numeric_scores.append(score)
        if max(numeric_scores) > 0:
            numeric_scores = [s / max(numeric_scores) for s in numeric_scores]
        else:
            numeric_scores = [0.0 for _ in numeric_scores]
        # Heuristic blend weights; semantic similarity dominates.
        final_scores = [
            0.6 * sim + 0.25 * kw + 0.15 * num
            for sim, kw, num in zip(similarities, keyword_scores, numeric_scores)
        ]
    else:
        final_scores = [0.7 * sim + 0.3 * kw for sim, kw in zip(similarities, keyword_scores)]
    ranked_indices = np.argsort(final_scores)[::-1]  # descending by score
    return {
        'documents': [[docs[i] for i in ranked_indices]],
        'metadatas': [[metas[i] for i in ranked_indices]],
        'distances': [[distances[i] for i in ranked_indices]],
        'scores': [final_scores[i] for i in ranked_indices]
    }
def build_context(search_results: Dict, max_length: int = 3000) -> str:
    """Concatenate retrieved chunks into a length-capped LLM prompt context.

    Each chunk is prefixed with a numbered header carrying its page number.
    When the budget is exceeded, the final chunk is truncated only if a
    meaningful amount of room (> 200 chars) remains; otherwise it is dropped.
    """
    docs = search_results['documents'][0]
    metas = search_results['metadatas'][0]
    parts = []
    used = 0
    for idx, (doc, meta) in enumerate(zip(docs, metas), 1):
        header = f"[๋ฌธ์ {idx}] (ํ์ด์ง {meta['page']})"
        piece = f"{header}\n{doc}\n"
        if used + len(piece) > max_length:
            budget = max_length - used
            if budget > 200:
                parts.append(f"{header}\n{doc[:budget-50]}...\n")
            break
        parts.append(piece)
        used += len(piece)
    return "\n".join(parts)
def generate_answer(query: str, search_results: Dict, api_key: str) -> str:
    """Ask the Grok chat API to answer the query from the retrieved context.

    Args:
        query: the user's question.
        search_results: (re-ranked) Chroma-shaped retrieval results.
        api_key: Grok API bearer token.

    Returns:
        The model's answer text, or a formatted error string on API/HTTP
        failure (this function never raises).
    """
    context = build_context(search_results, max_length=4000)
    system_prompt = """๋น์ ์ RFx ๋ฌธ์ ์ ๋ฌธ ๋ถ์๊ฐ์ ๋๋ค.
**์ค์ ์์น:**
1. ์ ๊ณต๋ ๋ฌธ์๋ฅผ **๋งค์ฐ ๊ผผ๊ผผํ** ์ฝ๊ณ ์ ํํ ์ ๋ณด๋ฅผ ์ฐพ์ผ์ธ์
2. ์ซ์, ๊ธ์ก, ๋ ์ง ๋ฑ ๊ตฌ์ฒด์ ์ธ ์ ๋ณด๋ฅผ ์ฐ์ ์ ์ผ๋ก ์ฐพ์ผ์ธ์
3. ๋ฌธ์์ ์ ๋ณด๊ฐ ์๋๋ฐ๋ "์๋ค"๊ณ ํ์ง ๋ง์ธ์
4. ๋ต๋ณ ์ ๋ฐ๋์ [๋ฌธ์ N, ํ์ด์ง X] ํํ๋ก ์ถ์ฒ ๋ช ์
5. ์ ๋งคํ ํํ ๋์ ๊ตฌ์ฒด์ ์ธ ์์น๋ฅผ ์ ๊ณตํ์ธ์
**๋ต๋ณ ํ์:**
- ํต์ฌ ๋ต๋ณ์ ๋จผ์ ๋ช ํํ๊ฒ ์ ์
- ์ถ์ฒ ๋ช ์ (ํ์ด์ง ๋ฒํธ ํฌํจ)
- ํ์์ ์ถ๊ฐ ๊ด๋ จ ์ ๋ณด ์ ๊ณต"""
    user_prompt = f"""๋ค์ ๋ฌธ์๋ค์ **๋งค์ฐ ๊ผผ๊ผผํ** ์ฝ๊ณ ์ง๋ฌธ์ ๋ต๋ณํ์ธ์.
<๋ฌธ์>
{context}
</๋ฌธ์>
<์ง๋ฌธ>
{query}
</์ง๋ฌธ>
**์ค์**:
- ๋ฌธ์๋ฅผ ์ฒ์๋ถํฐ ๋๊น์ง ์ฃผ์ ๊น๊ฒ ์ฝ์ผ์ธ์
- ์ซ์, ๊ธ์ก ๋ฑ ๊ตฌ์ฒด์ ์ธ ์ ๋ณด๋ฅผ ์ฐพ์ผ์ธ์
- ์ฐพ์ ์ ๋ณด๋ ์ ํํ ์ธ์ฉํ์ธ์
- ์ ๋ง๋ก ๋ฌธ์์ ์๋ ๊ฒฝ์ฐ์๋ง "์ฐพ์ ์ ์์ต๋๋ค"๋ผ๊ณ ํ์ธ์"""
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }
    payload = {
        "model": "grok-3",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        "temperature": 0.1,   # low temperature: factual extraction, not creativity
        "max_tokens": 2000,
        "stream": False
    }
    try:
        response = requests.post(
            f"{GROK_API_BASE}/chat/completions",
            headers=headers,
            json=payload,
            timeout=30
        )
        if response.status_code != 200:
            # Surface the API's own error message when the body is JSON;
            # fall back to the raw response text otherwise.
            error_detail = ""
            try:
                error_data = response.json()
                error_detail = error_data.get('error', {}).get('message', '')
            except Exception:
                error_detail = response.text
            return f"โ API ์ค๋ฅ (์ฝ๋: {response.status_code})\n\n{error_detail}"
        result = response.json()
        return result["choices"][0]["message"]["content"]
    except Exception as e:
        # Network errors, timeouts, malformed responses, etc.
        return f"โ ์ค๋ฅ: {str(e)}"
def highlight_text_in_pdf(pdf_bytes: bytes, highlight_info: List[Dict]) -> bytes:
    """Return a copy of the PDF with yellow highlight annotations added.

    Each entry in ``highlight_info`` is {'page': 1-based page, 'text': str}.
    For every entry, a few whitespace/comma-insensitive spellings of the text
    are searched for on the page and every occurrence is annotated.
    """
    doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    for entry in highlight_info:
        page_index = entry['page'] - 1
        target = entry['text']
        if page_index >= len(doc):
            continue  # stale highlight referencing a page beyond the document
        page = doc[page_index]
        candidates = [
            target,
            target.replace(' ', ''),
            target.replace(',', ''),
        ]
        for candidate in candidates:
            for rect in page.search_for(candidate):
                annot = page.add_highlight_annot(rect)
                annot.set_colors(stroke=[1, 1, 0])  # yellow
                annot.update()
    result = doc.tobytes()
    doc.close()
    return result
# Compiled once at module load: the original rebuilt and recompiled these
# pattern lists for every document on every call.
_MONEY_PATTERNS = [re.compile(p) for p in (
    r'\d{1,3}(?:,\d{3})+์',
    r'\d+๋ง์',
    r'\d+์ต์',
    r'\(์ผ๊ธ\s*[^)]+\)',
)]
_DATE_PATTERNS = [re.compile(p) for p in (
    r'\d{4}[๋ .]\d{1,2}[์.]\d{1,2}์ผ?',
    r'\d{2}\.\d{2}\.\d{2}',
)]


def extract_highlight_texts(documents: List[str], keywords: List[str]) -> List[str]:
    """Collect short spans worth highlighting from the retrieved documents.

    Picks monetary amounts, dates, and short (<100 char) sentences containing
    a query keyword; deduplicates and returns at most 10 spans, shortest first
    (so PDF text-search matches the most specific spans first).
    """
    highlight_texts = []
    for doc in documents:
        for pattern in _MONEY_PATTERNS:
            highlight_texts.extend(pattern.findall(doc))
        for pattern in _DATE_PATTERNS:
            highlight_texts.extend(pattern.findall(doc))
        for keyword in keywords:
            if keyword in doc:
                sentences = re.split(r'[.!?]\s+', doc)
                for sent in sentences:
                    if keyword in sent and len(sent) < 100:
                        highlight_texts.append(sent.strip())
    unique_texts = list(set(highlight_texts))
    unique_texts.sort(key=len)
    return unique_texts[:10]
def render_pdf_with_highlights(pdf_bytes: bytes, highlight_info: List[Dict]):
    """Render every page of the highlighted PDF as inline base64 <img> HTML.

    Pages are rasterized at 2x zoom; pages that received highlights get a
    banner listing up to three of the highlighted snippets.
    """
    annotated = highlight_text_in_pdf(pdf_bytes, highlight_info)
    doc = fitz.open(stream=annotated, filetype="pdf")
    pages_with_hits = set(h['page'] for h in highlight_info)

    parts = ['<div class="pdf-container">']
    for page_index in range(len(doc)):
        pix = doc[page_index].get_pixmap(matrix=fitz.Matrix(2, 2))  # 2x zoom render
        encoded = base64.b64encode(pix.tobytes("png")).decode()
        page_no = page_index + 1
        parts.append('<div style="margin-bottom: 2rem; position: relative;">')
        parts.append(f'<div style="background: #3B82F6; color: white; padding: 0.5rem; margin-bottom: 0.5rem; border-radius: 0.3rem; font-weight: bold;">๐ ํ์ด์ง {page_no}</div>')
        if page_no in pages_with_hits:
            hits = [h for h in highlight_info if h['page'] == page_no]
            labels = ', '.join([f'"{h["text"][:30]}..."' for h in hits[:3]])
            parts.append(f'<div class="highlight-indicator">โญ ํ์ด๋ผ์ดํธ: {labels}</div>')
        parts.append(f'<img src="data:image/png;base64,{encoded}" style="width: 100%; border: 1px solid #E2E8F0; border-radius: 0.3rem; box-shadow: 0 1px 3px rgba(0,0,0,0.1);" />')
        parts.append('</div>')
    parts.append('</div>')
    doc.close()
    return ''.join(parts)
def main():
    """Top-level Streamlit page: sidebar (settings/upload) + viewer/chat panes.

    Streamlit re-executes this on every interaction, so all durable state
    lives in st.session_state (populated by init_session).
    """
    init_session()
    st.markdown('<div class="main-title">๐ RFx ๋ฌธ์ ๋ถ์ AI ์์ด์ ํธ</div>', unsafe_allow_html=True)
    # --- Sidebar: API key, DB reset, upload/processing, document info --------
    with st.sidebar:
        st.header("โ๏ธ ์ค์ ")
        grok_key = st.text_input("Grok API Key", value=GROK_API_KEY or "", type="password")
        if grok_key:
            # Persist the key both in the process env and the session.
            os.environ["GROK_API_KEY"] = grok_key
            st.session_state.grok_key = grok_key
        st.divider()
        # Manual escape hatch for a corrupted Chroma directory.
        if st.button("๐ ๋ฐ์ดํฐ๋ฒ ์ด์ค ์ด๊ธฐํ", help="ChromaDB ์ค๋ฅ ๋ฐ์ ์ ํด๋ฆญ"):
            if os.path.exists(CHROMA_DIR):
                try:
                    shutil.rmtree(CHROMA_DIR)
                    st.success("โ ๋ฐ์ดํฐ๋ฒ ์ด์ค ์ด๊ธฐํ ์๋ฃ!")
                    st.session_state.processed = False
                    st.session_state.vector_db = None
                    st.rerun()
                except Exception as e:
                    st.error(f"์ด๊ธฐํ ์คํจ: {str(e)}")
        st.divider()
        st.subheader("๐ค ๋ฌธ์ ์ ๋ก๋")
        uploaded_file = st.file_uploader("PDF ํ์ผ ์ ํ", type=['pdf'])
        if uploaded_file:
            # 'processing' flag disables the button while a run is in flight.
            if st.button("๐ ๋ฌธ์ ์ฒ๋ฆฌ", type="primary", disabled=st.session_state.get('processing', False)):
                if not grok_key:
                    st.error("โ ๏ธ Grok API ํค๋ฅผ ์ ๋ ฅํ์ธ์!")
                    return
                st.session_state.processing = True
                with st.spinner("๐ ๋ฌธ์ ์ฒ๋ฆฌ ์ค..."):
                    try:
                        # Extract -> preview -> embed/index -> stash in session.
                        chunks, metadata_list, pdf_bytes, pages_text = extract_text_from_pdf(uploaded_file)
                        st.info(f"๐ {len(chunks)}๊ฐ ์ฒญํฌ ์ถ์ถ ์๋ฃ")
                        with st.expander("๐ ์ถ์ถ๋ ํ ์คํธ ์ํ", expanded=False):
                            if chunks:
                                st.text(f"์ฒซ ๋ฒ์งธ ์ฒญํฌ (์ด {len(chunks[0])}์):")
                                st.code(chunks[0][:500] + "..." if len(chunks[0]) > 500 else chunks[0])
                        with st.spinner("๐ง ๋ฒกํฐ ๋ฐ์ดํฐ๋ฒ ์ด์ค ์์ฑ ์ค..."):
                            collection, embedder = create_vector_db(chunks, metadata_list)
                        st.session_state.vector_db = collection
                        st.session_state.embedder = embedder
                        st.session_state.pdf_bytes = pdf_bytes
                        st.session_state.pdf_pages_text = pages_text
                        st.session_state.processed = True
                        st.session_state.doc_metadata = {
                            "filename": uploaded_file.name,
                            "chunks": len(chunks),
                            "pages": len(set(m['page'] for m in metadata_list))
                        }
                        st.success("โ ๋ฌธ์ ์ฒ๋ฆฌ ์๋ฃ!")
                    except Exception as e:
                        st.error(f"์ค๋ฅ: {str(e)}")
                    finally:
                        st.session_state.processing = False
        st.divider()
        if st.session_state.processed:
            # Summary of the currently indexed document.
            st.subheader("๐ ๋ฌธ์ ์ ๋ณด")
            meta = st.session_state.doc_metadata
            st.write(f"**ํ์ผ๋ช :** {meta['filename']}")
            st.write(f"**ํ์ด์ง:** {meta['pages']}ํ์ด์ง")
            st.write(f"**์ฒญํฌ:** {meta['chunks']}๊ฐ")
            if st.button("๐๏ธ ์ฑํ ์ด๊ธฐํ"):
                st.session_state.chat_history = []
                st.session_state.current_highlights = []
                st.rerun()
    # --- Main area ------------------------------------------------------------
    if not st.session_state.processed:
        # Landing state: feature overview until a document is indexed.
        st.info("๐ ์ผ์ชฝ ์ฌ์ด๋๋ฐ์์ PDF ๋ฌธ์๋ฅผ ์ ๋ก๋ํ์ธ์")
        col1, col2, col3 = st.columns(3)
        with col1:
            st.markdown("### ๐ PDF ๋ทฐ์ด\n์๋ณธ ๋ฌธ์ ํ์ธ")
        with col2:
            st.markdown("### ๐จ ํ์ด๋ผ์ดํธ\nํต์ฌ ๋ด์ฉ ๊ฐ์กฐ")
        with col3:
            st.markdown("### ๐ฌ AI ์ฑ๋ด\n์ ํํ ๋ต๋ณ")
    else:
        # Two panes: highlighted PDF viewer (left) and chat (right).
        col1, col2 = st.columns([1, 1])
        with col1:
            st.markdown("### ๐ ๋ฌธ์ ๋ทฐ์ด")
            if st.session_state.pdf_bytes:
                pdf_html = render_pdf_with_highlights(
                    st.session_state.pdf_bytes,
                    st.session_state.current_highlights
                )
                st.markdown(pdf_html, unsafe_allow_html=True)
        with col2:
            st.markdown("### ๐ฌ AI ์ฑ๋ด")
            chat_container = st.container()
            with chat_container:
                # Replay the persisted conversation.
                for msg in st.session_state.chat_history:
                    with st.chat_message(msg["role"]):
                        st.markdown(msg["content"])
                        if msg["role"] == "assistant" and "sources" in msg:
                            with st.expander("๐ ์ฐธ์กฐ ๋ฌธ์"):
                                for i, (doc, meta) in enumerate(zip(
                                    msg["sources"]["docs"],
                                    msg["sources"]["metas"]
                                ), 1):
                                    score = msg["sources"]["scores"][i-1] if "scores" in msg["sources"] else None
                                    score_text = f" (๊ด๋ จ๋: {score:.2%})" if score else ""
                                    st.markdown(f"""
                                    <div class="source-box">
                                        <div class="source-title">
                                            <span class="page-indicator">ํ์ด์ง {meta['page']}</span>
                                            {score_text}
                                        </div>
                                        <div style="font-size: 0.9rem; color: #475569;">
                                            {doc[:300]}{'...' if len(doc) > 300 else ''}
                                        </div>
                                    </div>
                                    """, unsafe_allow_html=True)
            # New question: rewrite -> retrieve -> rerank -> answer -> highlight.
            if prompt := st.chat_input("์ง๋ฌธ์ ์ ๋ ฅํ์ธ์...", disabled=st.session_state.get('processing', False)):
                if not st.session_state.get('grok_key'):
                    st.error("โ ๏ธ Grok API ํค๋ฅผ ์ ๋ ฅํด์ฃผ์ธ์!")
                    return
                with st.chat_message("user"):
                    st.markdown(prompt)
                st.session_state.chat_history.append({"role": "user", "content": prompt})
                with st.chat_message("assistant"):
                    with st.spinner("๐ ๊ฒ์ ๋ฐ ๋ถ์ ์ค..."):
                        try:
                            query_info = rewrite_query(prompt)
                            with st.expander("๐ ๊ฒ์ ๋๋ฒ๊ทธ ์ ๋ณด", expanded=False):
                                st.write("**์ถ์ถ๋ ํค์๋:**", query_info['keywords'])
                                st.write("**๊ฒ์ ์ฟผ๋ฆฌ ๋ณํ:**", query_info['variations'])
                            search_results = search_with_multiple_queries(
                                query_info['variations'],
                                st.session_state.vector_db,
                                st.session_state.embedder,
                                top_k=7
                            )
                            with st.expander("๐ ๊ฒ์๋ ๋ฌธ์ ๋ด์ฉ", expanded=False):
                                st.write(f"**์ด {search_results.get('total_found', 0)}๊ฐ ๋ฌธ์ ๋ฐ๊ฒฌ**")
                                for i, doc in enumerate(search_results['documents'][0][:3], 1):
                                    st.write(f"**๋ฌธ์ {i}:**")
                                    st.text(doc[:300] + "..." if len(doc) > 300 else doc)
                                    st.divider()
                            if 'total_found' in search_results:
                                st.success(f"โ {search_results['total_found']}๊ฐ ๋ฌธ์์์ ์์ 7๊ฐ ์ ํ")
                            reranked_results = rerank_results(
                                query_info['original'],
                                search_results,
                                st.session_state.embedder,
                                query_info['keywords']
                            )
                            answer = generate_answer(
                                query_info['original'],
                                reranked_results,
                                st.session_state.grok_key
                            )
                            st.markdown(answer)
                            # Derive highlight spans and map each to the page(s)
                            # whose retrieved chunk contains it.
                            highlight_texts = extract_highlight_texts(
                                reranked_results['documents'][0],
                                query_info['keywords']
                            )
                            highlights = []
                            for doc, meta in zip(reranked_results['documents'][0],
                                                 reranked_results['metadatas'][0]):
                                for text in highlight_texts:
                                    if text in doc:
                                        highlights.append({
                                            'page': meta['page'],
                                            'text': text
                                        })
                            st.session_state.current_highlights = highlights
                            st.session_state.chat_history.append({
                                "role": "assistant",
                                "content": answer,
                                "sources": {
                                    "docs": reranked_results['documents'][0],
                                    "metas": reranked_results['metadatas'][0],
                                    "scores": reranked_results.get('scores', []),
                                    "keywords": query_info['keywords']
                                }
                            })
                            with st.expander("๐ ์ฐธ์กฐ ๋ฌธ์", expanded=True):
                                for i, (doc, meta) in enumerate(zip(
                                    reranked_results['documents'][0],
                                    reranked_results['metadatas'][0]
                                ), 1):
                                    score = reranked_results.get('scores', [None])[i-1]
                                    score_text = f" (๊ด๋ จ๋: {score:.2%})" if score else ""
                                    st.markdown(f"""
                                    <div class="source-box">
                                        <div class="source-title">
                                            <span class="page-indicator">ํ์ด์ง {meta['page']}</span>
                                            {score_text}
                                        </div>
                                        <div style="font-size: 0.9rem; color: #475569;">
                                            {doc[:300]}{'...' if len(doc) > 300 else ''}
                                        </div>
                                    </div>
                                    """, unsafe_allow_html=True)
                            # Rerun so the viewer pane re-renders with highlights.
                            st.rerun()
                        except Exception as e:
                            st.error(f"์ค๋ฅ: {str(e)}")
                            import traceback
                            st.code(traceback.format_exc())
# Script entry point (Streamlit executes this module top-to-bottom on each run).
if __name__ == "__main__":
    main()