Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,15 +1,31 @@
|
|
| 1 |
import os
|
| 2 |
-
|
|
|
|
| 3 |
os.environ["TORCH_DYNAMO_DISABLE"] = "1"
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
import torch
|
| 6 |
-
#
|
| 7 |
torch.set_float32_matmul_precision('high')
|
|
|
|
|
|
|
| 8 |
import torch._inductor
|
| 9 |
torch._inductor.config.triton.cudagraphs = False
|
|
|
|
|
|
|
| 10 |
import torch._dynamo
|
|
|
|
|
|
|
| 11 |
import gradio as gr
|
| 12 |
import spaces
|
|
|
|
| 13 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
| 14 |
|
| 15 |
from threading import Thread
|
|
@@ -24,16 +40,14 @@ from datetime import datetime
|
|
| 24 |
import pyarrow.parquet as pq
|
| 25 |
import pypdf
|
| 26 |
import io
|
| 27 |
-
import pyarrow.parquet as pq
|
| 28 |
-
from tabulate import tabulate
|
| 29 |
import platform
|
| 30 |
import subprocess
|
| 31 |
import pytesseract
|
| 32 |
from pdf2image import convert_from_path
|
| 33 |
-
import queue #
|
| 34 |
-
import time
|
| 35 |
|
| 36 |
-
# --------------------
|
| 37 |
try:
|
| 38 |
import re
|
| 39 |
import requests
|
|
@@ -50,9 +64,6 @@ except ModuleNotFoundError as e:
|
|
| 50 |
)
|
| 51 |
# ---------------------------------------------------------------------------
|
| 52 |
|
| 53 |
-
# 1) Dynamo suppress_errors 옵션 사용 (오류 시 eager로 fallback)
|
| 54 |
-
torch._dynamo.config.suppress_errors = True
|
| 55 |
-
|
| 56 |
# 전역 변수
|
| 57 |
current_file_context = None
|
| 58 |
|
|
@@ -62,21 +73,21 @@ MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
|
|
| 62 |
MODELS = os.environ.get("MODELS")
|
| 63 |
MODEL_NAME = MODEL_ID.split("/")[-1]
|
| 64 |
|
| 65 |
-
model = None #
|
| 66 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
| 67 |
|
| 68 |
-
# 위키피디아 데이터셋 로드
|
| 69 |
wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
|
| 70 |
print("Wikipedia dataset loaded:", wiki_dataset)
|
| 71 |
|
| 72 |
-
# TF-IDF 벡터라이저 초기화 및 학습
|
| 73 |
print("TF-IDF 벡터화 시작...")
|
| 74 |
questions = wiki_dataset['train']['question'][:10000] # 처음 10000개만 사용
|
| 75 |
vectorizer = TfidfVectorizer(max_features=1000)
|
| 76 |
question_vectors = vectorizer.fit_transform(questions)
|
| 77 |
print("TF-IDF 벡터화 완료")
|
| 78 |
|
| 79 |
-
|
| 80 |
class ChatHistory:
|
| 81 |
def __init__(self):
|
| 82 |
self.history = []
|
|
@@ -132,19 +143,18 @@ class ChatHistory:
|
|
| 132 |
print(f"히스토리 로드 실패: {e}")
|
| 133 |
self.history = []
|
| 134 |
|
| 135 |
-
|
| 136 |
-
# 전역 ChatHistory 인스턴스 생성
|
| 137 |
chat_history = ChatHistory()
|
| 138 |
|
| 139 |
-
|
| 140 |
def find_relevant_context(query, top_k=3):
|
| 141 |
# 쿼리 벡터화
|
| 142 |
query_vector = vectorizer.transform([query])
|
| 143 |
-
# 코사인 유사도
|
| 144 |
similarities = (query_vector * question_vectors.T).toarray()[0]
|
| 145 |
-
#
|
| 146 |
top_indices = np.argsort(similarities)[-top_k:][::-1]
|
| 147 |
-
|
| 148 |
relevant_contexts = []
|
| 149 |
for idx in top_indices:
|
| 150 |
if similarities[idx] > 0:
|
|
@@ -155,16 +165,14 @@ def find_relevant_context(query, top_k=3):
|
|
| 155 |
})
|
| 156 |
return relevant_contexts
|
| 157 |
|
| 158 |
-
|
| 159 |
def init_msg():
|
| 160 |
return "파일을 분석하고 있습니다..."
|
| 161 |
|
| 162 |
-
|
| 163 |
# -------------------- PDF 파일을 Markdown으로 변환하는 유틸 함수들 --------------------
|
| 164 |
def extract_text_from_pdf(reader: PdfReader) -> str:
|
| 165 |
"""
|
| 166 |
PyPDF를 사용해 모든 페이지 텍스트를 추출.
|
| 167 |
-
만약 텍스트가 없으면 빈 문자열 반환.
|
| 168 |
"""
|
| 169 |
full_text = ""
|
| 170 |
for idx, page in enumerate(reader.pages):
|
|
@@ -173,20 +181,17 @@ def extract_text_from_pdf(reader: PdfReader) -> str:
|
|
| 173 |
full_text += f"---- Page {idx+1} ----\n" + text + "\n\n"
|
| 174 |
return full_text.strip()
|
| 175 |
|
| 176 |
-
|
| 177 |
def convert_pdf_to_markdown(pdf_file: str):
|
| 178 |
"""
|
| 179 |
-
PDF
|
| 180 |
-
이미지가 많고 텍스트가
|
| 181 |
-
최종적으로 Markdown 형식으로 변환 가능한 텍스트를 반환한다.
|
| 182 |
-
메타데이터도 함께 반환.
|
| 183 |
"""
|
| 184 |
try:
|
| 185 |
reader = PdfReader(pdf_file)
|
| 186 |
except Exception as e:
|
| 187 |
return f"PDF 파일을 읽는 중 오류 발생: {e}", None, None
|
| 188 |
|
| 189 |
-
#
|
| 190 |
raw_meta = reader.metadata
|
| 191 |
metadata = {
|
| 192 |
"author": raw_meta.author if raw_meta else None,
|
|
@@ -196,19 +201,16 @@ def convert_pdf_to_markdown(pdf_file: str):
|
|
| 196 |
"title": raw_meta.title if raw_meta else None,
|
| 197 |
}
|
| 198 |
|
| 199 |
-
#
|
| 200 |
full_text = extract_text_from_pdf(reader)
|
| 201 |
|
| 202 |
-
#
|
| 203 |
-
image_count =
|
| 204 |
-
for page in reader.pages:
|
| 205 |
-
image_count += len(page.images)
|
| 206 |
-
|
| 207 |
if image_count > 0 and len(full_text) < 1000:
|
| 208 |
try:
|
| 209 |
out_pdf_file = pdf_file.replace(".pdf", "_ocr.pdf")
|
| 210 |
ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
|
| 211 |
-
#
|
| 212 |
reader_ocr = PdfReader(out_pdf_file)
|
| 213 |
full_text = extract_text_from_pdf(reader_ocr)
|
| 214 |
except Exception as e:
|
|
@@ -216,11 +218,9 @@ def convert_pdf_to_markdown(pdf_file: str):
|
|
| 216 |
|
| 217 |
return full_text, metadata, pdf_file
|
| 218 |
|
| 219 |
-
|
| 220 |
-
# ---------------------------------------------------------------------------
|
| 221 |
-
|
| 222 |
def analyze_file_content(content, file_type):
|
| 223 |
-
"""
|
| 224 |
if file_type in ['parquet', 'csv']:
|
| 225 |
try:
|
| 226 |
lines = content.split('\n')
|
|
@@ -245,15 +245,13 @@ def analyze_file_content(content, file_type):
|
|
| 245 |
words = len(content.split())
|
| 246 |
return f"📝 Document Structure: {total_lines} lines, {paragraphs} paragraphs, approximately {words} words"
|
| 247 |
|
| 248 |
-
|
| 249 |
def read_uploaded_file(file):
|
| 250 |
"""
|
| 251 |
-
업로드된
|
| 252 |
-
1) 파일 타입별로 내용을 읽고
|
| 253 |
-
2) 분석 결과와 함께 반환
|
| 254 |
"""
|
| 255 |
if file is None:
|
| 256 |
return "", ""
|
|
|
|
| 257 |
try:
|
| 258 |
file_ext = os.path.splitext(file.name)[1].lower()
|
| 259 |
|
|
@@ -267,7 +265,8 @@ def read_uploaded_file(file):
|
|
| 267 |
content += f"1. Basic Information:\n"
|
| 268 |
content += f"- Total Rows: {len(df):,}\n"
|
| 269 |
content += f"- Total Columns: {len(df.columns)}\n"
|
| 270 |
-
|
|
|
|
| 271 |
|
| 272 |
content += f"2. Column Information:\n"
|
| 273 |
for col in df.columns:
|
|
@@ -279,7 +278,8 @@ def read_uploaded_file(file):
|
|
| 279 |
content += f"\n\n4. Missing Values:\n"
|
| 280 |
null_counts = df.isnull().sum()
|
| 281 |
for col, count in null_counts[null_counts > 0].items():
|
| 282 |
-
|
|
|
|
| 283 |
|
| 284 |
numeric_cols = df.select_dtypes(include=['int64', 'float64']).columns
|
| 285 |
if len(numeric_cols) > 0:
|
|
@@ -291,7 +291,7 @@ def read_uploaded_file(file):
|
|
| 291 |
except Exception as e:
|
| 292 |
return f"Error reading Parquet file: {str(e)}", "error"
|
| 293 |
|
| 294 |
-
# PDF
|
| 295 |
if file_ext == '.pdf':
|
| 296 |
try:
|
| 297 |
markdown_text, metadata, processed_pdf_path = convert_pdf_to_markdown(file.name)
|
|
@@ -305,7 +305,6 @@ def read_uploaded_file(file):
|
|
| 305 |
|
| 306 |
content += "## Extracted Text\n\n"
|
| 307 |
content += markdown_text
|
| 308 |
-
|
| 309 |
return content, "pdf"
|
| 310 |
except Exception as e:
|
| 311 |
return f"Error reading PDF file: {str(e)}", "error"
|
|
@@ -320,7 +319,8 @@ def read_uploaded_file(file):
|
|
| 320 |
content += f"1. Basic Information:\n"
|
| 321 |
content += f"- Total Rows: {len(df):,}\n"
|
| 322 |
content += f"- Total Columns: {len(df.columns)}\n"
|
| 323 |
-
|
|
|
|
| 324 |
|
| 325 |
content += f"2. Column Information:\n"
|
| 326 |
for col in df.columns:
|
|
@@ -332,14 +332,17 @@ def read_uploaded_file(file):
|
|
| 332 |
content += f"\n\n4. Missing Values:\n"
|
| 333 |
null_counts = df.isnull().sum()
|
| 334 |
for col, count in null_counts[null_counts > 0].items():
|
| 335 |
-
|
|
|
|
| 336 |
|
| 337 |
return content, "csv"
|
| 338 |
except UnicodeDecodeError:
|
| 339 |
continue
|
| 340 |
-
raise UnicodeDecodeError(
|
|
|
|
|
|
|
| 341 |
|
| 342 |
-
#
|
| 343 |
else:
|
| 344 |
encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
|
| 345 |
for encoding in encodings:
|
|
@@ -350,15 +353,19 @@ def read_uploaded_file(file):
|
|
| 350 |
lines = content.split('\n')
|
| 351 |
total_lines = len(lines)
|
| 352 |
non_empty_lines = len([line for line in lines if line.strip()])
|
| 353 |
-
|
| 354 |
-
|
|
|
|
|
|
|
| 355 |
|
| 356 |
analysis = f"\n📝 File Analysis:\n"
|
| 357 |
if is_code:
|
| 358 |
-
functions =
|
| 359 |
-
classes =
|
| 360 |
-
imports =
|
| 361 |
-
|
|
|
|
|
|
|
| 362 |
analysis += f"- File Type: Code\n"
|
| 363 |
analysis += f"- Total Lines: {total_lines:,}\n"
|
| 364 |
analysis += f"- Functions: {functions}\n"
|
|
@@ -375,14 +382,18 @@ def read_uploaded_file(file):
|
|
| 375 |
analysis += f"- Character Count: {chars:,}\n"
|
| 376 |
|
| 377 |
return content + analysis, "text"
|
|
|
|
| 378 |
except UnicodeDecodeError:
|
| 379 |
continue
|
| 380 |
-
|
|
|
|
|
|
|
|
|
|
| 381 |
|
| 382 |
except Exception as e:
|
| 383 |
return f"Error reading file: {str(e)}", "error"
|
| 384 |
|
| 385 |
-
|
| 386 |
CSS = """
|
| 387 |
/* 3D 스타일 CSS */
|
| 388 |
:root {
|
|
@@ -539,22 +550,20 @@ body {
|
|
| 539 |
"""
|
| 540 |
|
| 541 |
def clear_cuda_memory():
|
|
|
|
| 542 |
if hasattr(torch.cuda, 'empty_cache'):
|
| 543 |
with torch.cuda.device('cuda'):
|
| 544 |
torch.cuda.empty_cache()
|
| 545 |
|
| 546 |
-
|
| 547 |
@spaces.GPU
|
| 548 |
def load_model():
|
| 549 |
try:
|
| 550 |
-
# 메모리 정리 먼저 수행
|
| 551 |
clear_cuda_memory()
|
| 552 |
-
|
| 553 |
loaded_model = AutoModelForCausalLM.from_pretrained(
|
| 554 |
MODEL_ID,
|
| 555 |
torch_dtype=torch.bfloat16,
|
| 556 |
device_map="auto",
|
| 557 |
-
# 낮은 메모리 사용을 위한 설정 추가
|
| 558 |
low_cpu_mem_usage=True,
|
| 559 |
)
|
| 560 |
return loaded_model
|
|
@@ -562,22 +571,8 @@ def load_model():
|
|
| 562 |
print(f"모델 로드 오류: {str(e)}")
|
| 563 |
raise
|
| 564 |
|
| 565 |
-
def _truncate_tokens_for_context(input_ids_str: str, desired_input_length: int) -> str:
|
| 566 |
-
"""
|
| 567 |
-
입력 문자열이 desired_input_length 토큰을 넘으면, 앞부분(오래된 컨텍스트)을 잘라내는 함수.
|
| 568 |
-
"""
|
| 569 |
-
tokens = input_ids_str.split()
|
| 570 |
-
if len(tokens) > desired_input_length:
|
| 571 |
-
tokens = tokens[-desired_input_length:]
|
| 572 |
-
return " ".join(tokens)
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
# build_prompt 함수: 대화 내역을 문자열로 변환
|
| 576 |
def build_prompt(conversation: list) -> str:
|
| 577 |
-
"""
|
| 578 |
-
conversation은 각 항목이 {"role": "user" 또는 "assistant", "content": ...} 형태의 딕셔너리 목록입니다.
|
| 579 |
-
이를 단순 텍스트 프롬프트로 변환합니다.
|
| 580 |
-
"""
|
| 581 |
prompt = ""
|
| 582 |
for msg in conversation:
|
| 583 |
if msg["role"] == "user":
|
|
@@ -587,7 +582,7 @@ def build_prompt(conversation: list) -> str:
|
|
| 587 |
prompt += "Assistant: "
|
| 588 |
return prompt
|
| 589 |
|
| 590 |
-
|
| 591 |
@spaces.GPU
|
| 592 |
def stream_chat(
|
| 593 |
message: str,
|
|
@@ -602,13 +597,14 @@ def stream_chat(
|
|
| 602 |
global model, current_file_context
|
| 603 |
|
| 604 |
try:
|
|
|
|
| 605 |
if model is None:
|
| 606 |
model = load_model()
|
| 607 |
|
| 608 |
-
print(f'
|
| 609 |
-
print(f'
|
| 610 |
|
| 611 |
-
# 파일 업로드 처리
|
| 612 |
file_context = ""
|
| 613 |
if uploaded_file and message == "파일을 분석하고 있습니다...":
|
| 614 |
current_file_context = None
|
|
@@ -623,23 +619,16 @@ def stream_chat(
|
|
| 623 |
current_file_context = file_context
|
| 624 |
message = "업로드된 파일을 분석해주세요."
|
| 625 |
except Exception as e:
|
| 626 |
-
print(f"파일 분석
|
| 627 |
file_context = f"\n\n❌ 파일 분석 중 오류가 발생했습니다: {str(e)}"
|
| 628 |
elif current_file_context:
|
| 629 |
file_context = current_file_context
|
| 630 |
|
| 631 |
-
|
| 632 |
-
print(f"CUDA 메모리 사용량: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
|
| 633 |
-
|
| 634 |
-
max_history_length = 10
|
| 635 |
-
if len(history) > max_history_length:
|
| 636 |
-
history = history[-max_history_length:]
|
| 637 |
-
|
| 638 |
-
# 위키피디아 컨텍스트 검색
|
| 639 |
wiki_context = ""
|
| 640 |
try:
|
| 641 |
relevant_contexts = find_relevant_context(message)
|
| 642 |
-
if relevant_contexts:
|
| 643 |
wiki_context = "\n\n관련 위키피디아 정보:\n"
|
| 644 |
for ctx in relevant_contexts:
|
| 645 |
wiki_context += (
|
|
@@ -648,9 +637,13 @@ def stream_chat(
|
|
| 648 |
f"유사도: {ctx['similarity']:.3f}\n\n"
|
| 649 |
)
|
| 650 |
except Exception as e:
|
| 651 |
-
print(f"컨텍스트 검색
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 652 |
|
| 653 |
-
# 대화 내역 구성
|
| 654 |
conversation = []
|
| 655 |
for prompt, answer in history:
|
| 656 |
conversation.extend([
|
|
@@ -658,7 +651,7 @@ def stream_chat(
|
|
| 658 |
{"role": "assistant", "content": answer}
|
| 659 |
])
|
| 660 |
|
| 661 |
-
# 최종 메시지
|
| 662 |
final_message = message
|
| 663 |
if file_context:
|
| 664 |
final_message = file_context + "\n현재 질문: " + message
|
|
@@ -666,53 +659,42 @@ def stream_chat(
|
|
| 666 |
final_message = wiki_context + "\n현재 질문: " + message
|
| 667 |
if file_context and wiki_context:
|
| 668 |
final_message = file_context + wiki_context + "\n현재 질문: " + message
|
| 669 |
-
|
| 670 |
conversation.append({"role": "user", "content": final_message})
|
| 671 |
|
| 672 |
-
#
|
| 673 |
input_ids_str = build_prompt(conversation)
|
| 674 |
-
|
| 675 |
-
# 먼저 컨텍스트 길이 확인 및 제한
|
| 676 |
max_context = 8192
|
| 677 |
tokenized_input = tokenizer(input_ids_str, return_tensors="pt")
|
| 678 |
input_length = tokenized_input["input_ids"].shape[1]
|
| 679 |
-
|
| 680 |
-
# 컨텍스트가 너무 길면 자르기
|
| 681 |
if input_length > max_context - max_new_tokens:
|
| 682 |
-
print(f"입력이 너무 깁니다: {input_length}
|
| 683 |
-
# 최소 생성 토큰 수 확보
|
| 684 |
min_generation = min(256, max_new_tokens)
|
| 685 |
new_desired_input_length = max_context - min_generation
|
| 686 |
-
|
| 687 |
-
# 입력 텍스트를 토큰 단위로 자르기
|
| 688 |
tokens = tokenizer.encode(input_ids_str)
|
| 689 |
if len(tokens) > new_desired_input_length:
|
| 690 |
tokens = tokens[-new_desired_input_length:]
|
| 691 |
input_ids_str = tokenizer.decode(tokens)
|
| 692 |
-
|
| 693 |
-
# 다시 토큰화
|
| 694 |
tokenized_input = tokenizer(input_ids_str, return_tensors="pt")
|
| 695 |
input_length = tokenized_input["input_ids"].shape[1]
|
| 696 |
-
|
| 697 |
-
print(f"
|
| 698 |
-
|
| 699 |
-
# CUDA로 입력 이동
|
| 700 |
inputs = tokenized_input.to("cuda")
|
| 701 |
-
|
| 702 |
-
# 남은 토큰
|
| 703 |
remaining = max_context - input_length
|
| 704 |
if remaining < max_new_tokens:
|
| 705 |
-
print(f"max_new_tokens
|
| 706 |
max_new_tokens = remaining
|
| 707 |
|
| 708 |
-
print(f"입력 텐서 생성 후 CUDA 메모리: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
|
| 709 |
-
|
| 710 |
# 스트리머 설정
|
| 711 |
streamer = TextIteratorStreamer(
|
| 712 |
tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True
|
| 713 |
)
|
| 714 |
-
|
| 715 |
-
# 생성
|
| 716 |
generate_kwargs = dict(
|
| 717 |
**inputs,
|
| 718 |
streamer=streamer,
|
|
@@ -727,63 +709,56 @@ def stream_chat(
|
|
| 727 |
use_cache=True
|
| 728 |
)
|
| 729 |
|
| 730 |
-
# 메모리 정리
|
| 731 |
clear_cuda_memory()
|
| 732 |
|
| 733 |
-
# 별도 스레드에서 생성
|
| 734 |
thread = Thread(target=model.generate, kwargs=generate_kwargs)
|
| 735 |
thread.start()
|
| 736 |
|
| 737 |
-
#
|
| 738 |
buffer = ""
|
| 739 |
partial_message = ""
|
| 740 |
last_yield_time = time.time()
|
| 741 |
-
|
| 742 |
try:
|
| 743 |
for new_text in streamer:
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
print(f"개별 토큰 처리 중 오류: {str(inner_e)}")
|
| 756 |
-
continue
|
| 757 |
-
|
| 758 |
-
# 마지막 응답 확인
|
| 759 |
if buffer:
|
| 760 |
yield "", history + [[message, buffer]]
|
| 761 |
-
|
| 762 |
-
# 대화
|
| 763 |
chat_history.add_conversation(message, buffer)
|
| 764 |
-
|
| 765 |
except Exception as e:
|
| 766 |
-
print(f"스트리밍 중 오류
|
| 767 |
-
if not buffer: #
|
| 768 |
-
buffer = f"응답 생성 중
|
| 769 |
yield "", history + [[message, buffer]]
|
| 770 |
-
|
| 771 |
-
# 스레드가 여전히 실행 중이면 종료 대기
|
| 772 |
if thread.is_alive():
|
| 773 |
thread.join(timeout=5.0)
|
| 774 |
-
|
| 775 |
-
# 메모리 정리
|
| 776 |
clear_cuda_memory()
|
| 777 |
|
| 778 |
except Exception as e:
|
| 779 |
import traceback
|
| 780 |
error_details = traceback.format_exc()
|
| 781 |
error_message = f"오류가 발생했습니다: {str(e)}\n{error_details}"
|
| 782 |
-
print(f"Stream chat
|
| 783 |
clear_cuda_memory()
|
| 784 |
yield "", history + [[message, error_message]]
|
| 785 |
|
| 786 |
-
|
| 787 |
def create_demo():
|
| 788 |
with gr.Blocks(css=CSS) as demo:
|
| 789 |
with gr.Column(elem_classes="markdown-style"):
|
|
@@ -834,6 +809,7 @@ def create_demo():
|
|
| 834 |
scale=1
|
| 835 |
)
|
| 836 |
|
|
|
|
| 837 |
with gr.Accordion("🎮 Advanced Settings", open=False):
|
| 838 |
with gr.Row():
|
| 839 |
with gr.Column(scale=1):
|
|
@@ -859,6 +835,7 @@ def create_demo():
|
|
| 859 |
label="Repetition Penalty 🔄"
|
| 860 |
)
|
| 861 |
|
|
|
|
| 862 |
gr.Examples(
|
| 863 |
examples=[
|
| 864 |
["Please analyze this code and suggest improvements:\ndef fibonacci(n):\n if n <= 1: return n\n return fibonacci(n-1) + fibonacci(n-2)"],
|
|
@@ -869,23 +846,25 @@ def create_demo():
|
|
| 869 |
inputs=msg
|
| 870 |
)
|
| 871 |
|
|
|
|
| 872 |
def clear_conversation():
|
| 873 |
global current_file_context
|
| 874 |
current_file_context = None
|
| 875 |
return [], None, "Start a new conversation..."
|
| 876 |
|
|
|
|
| 877 |
msg.submit(
|
| 878 |
stream_chat,
|
| 879 |
inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty],
|
| 880 |
outputs=[msg, chatbot]
|
| 881 |
)
|
| 882 |
-
|
| 883 |
send.click(
|
| 884 |
stream_chat,
|
| 885 |
inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty],
|
| 886 |
outputs=[msg, chatbot]
|
| 887 |
)
|
| 888 |
|
|
|
|
| 889 |
file_upload.change(
|
| 890 |
fn=lambda: ("처리 중...", [["시스템", "파일을 분석 중입니다. 잠시만 기다려주세요..."]]),
|
| 891 |
outputs=[msg, chatbot],
|
|
@@ -901,6 +880,7 @@ def create_demo():
|
|
| 901 |
queue=True
|
| 902 |
)
|
| 903 |
|
|
|
|
| 904 |
clear.click(
|
| 905 |
fn=clear_conversation,
|
| 906 |
outputs=[chatbot, file_upload, msg],
|
|
@@ -909,7 +889,7 @@ def create_demo():
|
|
| 909 |
|
| 910 |
return demo
|
| 911 |
|
| 912 |
-
|
| 913 |
if __name__ == "__main__":
|
| 914 |
demo = create_demo()
|
| 915 |
-
demo.launch()
|
|
|
|
| 1 |
import os
|
| 2 |
+
|
| 3 |
+
# 1) Dynamo 완전 비활성화
|
| 4 |
os.environ["TORCH_DYNAMO_DISABLE"] = "1"
|
| 5 |
|
| 6 |
+
# 2) Triton의 cudagraphs 최적화 비활성화
|
| 7 |
+
os.environ["TRITON_DISABLE_CUDAGRAPHS"] = "1"
|
| 8 |
+
|
| 9 |
+
# 3) 경고 무시 설정 (skipping cudagraphs 관련)
|
| 10 |
+
import warnings
|
| 11 |
+
warnings.filterwarnings("ignore", message="skipping cudagraphs due to mutated inputs")
|
| 12 |
+
warnings.filterwarnings("ignore", message="Not enough SMs to use max_autotune_gemm mode")
|
| 13 |
+
|
| 14 |
import torch
|
| 15 |
+
# TensorFloat32 연산 활성화 (성능 최적화)
|
| 16 |
torch.set_float32_matmul_precision('high')
|
| 17 |
+
|
| 18 |
+
# TorchInductor cudagraphs 비활성화
|
| 19 |
import torch._inductor
|
| 20 |
torch._inductor.config.triton.cudagraphs = False
|
| 21 |
+
|
| 22 |
+
# Dynamo suppress_errors 옵션 (오류 시 eager로 fallback)
|
| 23 |
import torch._dynamo
|
| 24 |
+
torch._dynamo.config.suppress_errors = True
|
| 25 |
+
|
| 26 |
import gradio as gr
|
| 27 |
import spaces
|
| 28 |
+
|
| 29 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
| 30 |
|
| 31 |
from threading import Thread
|
|
|
|
| 40 |
import pyarrow.parquet as pq
|
| 41 |
import pypdf
|
| 42 |
import io
|
|
|
|
|
|
|
| 43 |
import platform
|
| 44 |
import subprocess
|
| 45 |
import pytesseract
|
| 46 |
from pdf2image import convert_from_path
|
| 47 |
+
import queue # queue.Empty 예외 처리를 위해
|
| 48 |
+
import time # 스트리밍 타이밍을 위해
|
| 49 |
|
| 50 |
+
# -------------------- PDF to Markdown 변환 관련 import --------------------
|
| 51 |
try:
|
| 52 |
import re
|
| 53 |
import requests
|
|
|
|
| 64 |
)
|
| 65 |
# ---------------------------------------------------------------------------
|
| 66 |
|
|
|
|
|
|
|
|
|
|
| 67 |
# 전역 변수
|
| 68 |
current_file_context = None
|
| 69 |
|
|
|
|
| 73 |
MODELS = os.environ.get("MODELS")
|
| 74 |
MODEL_NAME = MODEL_ID.split("/")[-1]
|
| 75 |
|
| 76 |
+
model = None # 전역에서 관리
|
| 77 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
| 78 |
|
| 79 |
+
# (1) 위키피디아 데이터셋 로드
|
| 80 |
wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
|
| 81 |
print("Wikipedia dataset loaded:", wiki_dataset)
|
| 82 |
|
| 83 |
+
# (2) TF-IDF 벡터라이저 초기화 및 학습
|
| 84 |
print("TF-IDF 벡터화 시작...")
|
| 85 |
questions = wiki_dataset['train']['question'][:10000] # 처음 10000개만 사용
|
| 86 |
vectorizer = TfidfVectorizer(max_features=1000)
|
| 87 |
question_vectors = vectorizer.fit_transform(questions)
|
| 88 |
print("TF-IDF 벡터화 완료")
|
| 89 |
|
| 90 |
+
# ------------------------- ChatHistory 클래스 -------------------------
|
| 91 |
class ChatHistory:
|
| 92 |
def __init__(self):
|
| 93 |
self.history = []
|
|
|
|
| 143 |
print(f"히스토리 로드 실패: {e}")
|
| 144 |
self.history = []
|
| 145 |
|
| 146 |
+
# 전역 ChatHistory 인스턴스
|
|
|
|
| 147 |
chat_history = ChatHistory()
|
| 148 |
|
| 149 |
+
# ------------------------- 위키 문서 검색 (TF-IDF) -------------------------
|
| 150 |
def find_relevant_context(query, top_k=3):
|
| 151 |
# 쿼리 벡터화
|
| 152 |
query_vector = vectorizer.transform([query])
|
| 153 |
+
# 코사인 유사도
|
| 154 |
similarities = (query_vector * question_vectors.T).toarray()[0]
|
| 155 |
+
# 유사도 높은 질문 인덱스
|
| 156 |
top_indices = np.argsort(similarities)[-top_k:][::-1]
|
| 157 |
+
|
| 158 |
relevant_contexts = []
|
| 159 |
for idx in top_indices:
|
| 160 |
if similarities[idx] > 0:
|
|
|
|
| 165 |
})
|
| 166 |
return relevant_contexts
|
| 167 |
|
| 168 |
+
# 파일 업로드 시 표시할 초기 메시지
|
| 169 |
def init_msg():
|
| 170 |
return "파일을 분석하고 있습니다..."
|
| 171 |
|
|
|
|
| 172 |
# -------------------- PDF 파일을 Markdown으로 변환하는 유틸 함수들 --------------------
|
| 173 |
def extract_text_from_pdf(reader: PdfReader) -> str:
|
| 174 |
"""
|
| 175 |
PyPDF를 사용해 모든 페이지 텍스트를 추출.
|
|
|
|
| 176 |
"""
|
| 177 |
full_text = ""
|
| 178 |
for idx, page in enumerate(reader.pages):
|
|
|
|
| 181 |
full_text += f"---- Page {idx+1} ----\n" + text + "\n\n"
|
| 182 |
return full_text.strip()
|
| 183 |
|
|
|
|
| 184 |
def convert_pdf_to_markdown(pdf_file: str):
|
| 185 |
"""
|
| 186 |
+
PDF 파일에서 텍스트를 추출하고,
|
| 187 |
+
이미지가 많고 텍스트가 적으면 OCR 시도
|
|
|
|
|
|
|
| 188 |
"""
|
| 189 |
try:
|
| 190 |
reader = PdfReader(pdf_file)
|
| 191 |
except Exception as e:
|
| 192 |
return f"PDF 파일을 읽는 중 오류 발생: {e}", None, None
|
| 193 |
|
| 194 |
+
# 메타데이터 추출
|
| 195 |
raw_meta = reader.metadata
|
| 196 |
metadata = {
|
| 197 |
"author": raw_meta.author if raw_meta else None,
|
|
|
|
| 201 |
"title": raw_meta.title if raw_meta else None,
|
| 202 |
}
|
| 203 |
|
| 204 |
+
# 텍스트 추출
|
| 205 |
full_text = extract_text_from_pdf(reader)
|
| 206 |
|
| 207 |
+
# 이미지-텍스트 비율 판단 후 OCR 시도
|
| 208 |
+
image_count = sum(len(page.images) for page in reader.pages)
|
|
|
|
|
|
|
|
|
|
| 209 |
if image_count > 0 and len(full_text) < 1000:
|
| 210 |
try:
|
| 211 |
out_pdf_file = pdf_file.replace(".pdf", "_ocr.pdf")
|
| 212 |
ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
|
| 213 |
+
# OCR된 PDF 다시 읽기
|
| 214 |
reader_ocr = PdfReader(out_pdf_file)
|
| 215 |
full_text = extract_text_from_pdf(reader_ocr)
|
| 216 |
except Exception as e:
|
|
|
|
| 218 |
|
| 219 |
return full_text, metadata, pdf_file
|
| 220 |
|
| 221 |
+
# ------------------------- 파일 분석 함수 -------------------------
|
|
|
|
|
|
|
| 222 |
def analyze_file_content(content, file_type):
|
| 223 |
+
"""간단한 구조 분석/요약."""
|
| 224 |
if file_type in ['parquet', 'csv']:
|
| 225 |
try:
|
| 226 |
lines = content.split('\n')
|
|
|
|
| 245 |
words = len(content.split())
|
| 246 |
return f"📝 Document Structure: {total_lines} lines, {paragraphs} paragraphs, approximately {words} words"
|
| 247 |
|
|
|
|
| 248 |
def read_uploaded_file(file):
|
| 249 |
"""
|
| 250 |
+
업로드된 파일 처리 -> 내용/타입
|
|
|
|
|
|
|
| 251 |
"""
|
| 252 |
if file is None:
|
| 253 |
return "", ""
|
| 254 |
+
|
| 255 |
try:
|
| 256 |
file_ext = os.path.splitext(file.name)[1].lower()
|
| 257 |
|
|
|
|
| 265 |
content += f"1. Basic Information:\n"
|
| 266 |
content += f"- Total Rows: {len(df):,}\n"
|
| 267 |
content += f"- Total Columns: {len(df.columns)}\n"
|
| 268 |
+
mem_usage = df.memory_usage(deep=True).sum() / 1024 / 1024
|
| 269 |
+
content += f"- Memory Usage: {mem_usage:.2f} MB\n\n"
|
| 270 |
|
| 271 |
content += f"2. Column Information:\n"
|
| 272 |
for col in df.columns:
|
|
|
|
| 278 |
content += f"\n\n4. Missing Values:\n"
|
| 279 |
null_counts = df.isnull().sum()
|
| 280 |
for col, count in null_counts[null_counts > 0].items():
|
| 281 |
+
rate = count / len(df) * 100
|
| 282 |
+
content += f"- {col}: {count:,} ({rate:.1f}%)\n"
|
| 283 |
|
| 284 |
numeric_cols = df.select_dtypes(include=['int64', 'float64']).columns
|
| 285 |
if len(numeric_cols) > 0:
|
|
|
|
| 291 |
except Exception as e:
|
| 292 |
return f"Error reading Parquet file: {str(e)}", "error"
|
| 293 |
|
| 294 |
+
# PDF
|
| 295 |
if file_ext == '.pdf':
|
| 296 |
try:
|
| 297 |
markdown_text, metadata, processed_pdf_path = convert_pdf_to_markdown(file.name)
|
|
|
|
| 305 |
|
| 306 |
content += "## Extracted Text\n\n"
|
| 307 |
content += markdown_text
|
|
|
|
| 308 |
return content, "pdf"
|
| 309 |
except Exception as e:
|
| 310 |
return f"Error reading PDF file: {str(e)}", "error"
|
|
|
|
| 319 |
content += f"1. Basic Information:\n"
|
| 320 |
content += f"- Total Rows: {len(df):,}\n"
|
| 321 |
content += f"- Total Columns: {len(df.columns)}\n"
|
| 322 |
+
mem_usage = df.memory_usage(deep=True).sum() / 1024 / 1024
|
| 323 |
+
content += f"- Memory Usage: {mem_usage:.2f} MB\n\n"
|
| 324 |
|
| 325 |
content += f"2. Column Information:\n"
|
| 326 |
for col in df.columns:
|
|
|
|
| 332 |
content += f"\n\n4. Missing Values:\n"
|
| 333 |
null_counts = df.isnull().sum()
|
| 334 |
for col, count in null_counts[null_counts > 0].items():
|
| 335 |
+
rate = count / len(df) * 100
|
| 336 |
+
content += f"- {col}: {count:,} ({rate:.1f}%)\n"
|
| 337 |
|
| 338 |
return content, "csv"
|
| 339 |
except UnicodeDecodeError:
|
| 340 |
continue
|
| 341 |
+
raise UnicodeDecodeError(
|
| 342 |
+
f"Unable to read file with supported encodings ({', '.join(encodings)})"
|
| 343 |
+
)
|
| 344 |
|
| 345 |
+
# 텍스트 파일
|
| 346 |
else:
|
| 347 |
encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
|
| 348 |
for encoding in encodings:
|
|
|
|
| 353 |
lines = content.split('\n')
|
| 354 |
total_lines = len(lines)
|
| 355 |
non_empty_lines = len([line for line in lines if line.strip()])
|
| 356 |
+
is_code = any(
|
| 357 |
+
keyword in content.lower()
|
| 358 |
+
for keyword in ['def ', 'class ', 'import ', 'function']
|
| 359 |
+
)
|
| 360 |
|
| 361 |
analysis = f"\n📝 File Analysis:\n"
|
| 362 |
if is_code:
|
| 363 |
+
functions = sum('def ' in line for line in lines)
|
| 364 |
+
classes = sum('class ' in line for line in lines)
|
| 365 |
+
imports = sum(
|
| 366 |
+
('import ' in line) or ('from ' in line)
|
| 367 |
+
for line in lines
|
| 368 |
+
)
|
| 369 |
analysis += f"- File Type: Code\n"
|
| 370 |
analysis += f"- Total Lines: {total_lines:,}\n"
|
| 371 |
analysis += f"- Functions: {functions}\n"
|
|
|
|
| 382 |
analysis += f"- Character Count: {chars:,}\n"
|
| 383 |
|
| 384 |
return content + analysis, "text"
|
| 385 |
+
|
| 386 |
except UnicodeDecodeError:
|
| 387 |
continue
|
| 388 |
+
|
| 389 |
+
raise UnicodeDecodeError(
|
| 390 |
+
f"Unable to read file with supported encodings ({', '.join(encodings)})"
|
| 391 |
+
)
|
| 392 |
|
| 393 |
except Exception as e:
|
| 394 |
return f"Error reading file: {str(e)}", "error"
|
| 395 |
|
| 396 |
+
# ------------------------- CSS -------------------------
|
| 397 |
CSS = """
|
| 398 |
/* 3D 스타일 CSS */
|
| 399 |
:root {
|
|
|
|
| 550 |
"""
|
| 551 |
|
| 552 |
def clear_cuda_memory():
|
| 553 |
+
"""CUDA 캐시 정리."""
|
| 554 |
if hasattr(torch.cuda, 'empty_cache'):
|
| 555 |
with torch.cuda.device('cuda'):
|
| 556 |
torch.cuda.empty_cache()
|
| 557 |
|
| 558 |
+
# ------------------------- 모델 로딩 함수 -------------------------
|
| 559 |
@spaces.GPU
|
| 560 |
def load_model():
|
| 561 |
try:
|
|
|
|
| 562 |
clear_cuda_memory()
|
|
|
|
| 563 |
loaded_model = AutoModelForCausalLM.from_pretrained(
|
| 564 |
MODEL_ID,
|
| 565 |
torch_dtype=torch.bfloat16,
|
| 566 |
device_map="auto",
|
|
|
|
| 567 |
low_cpu_mem_usage=True,
|
| 568 |
)
|
| 569 |
return loaded_model
|
|
|
|
| 571 |
print(f"모델 로드 오류: {str(e)}")
|
| 572 |
raise
|
| 573 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 574 |
def build_prompt(conversation: list) -> str:
|
| 575 |
+
"""대화 내역을 단순 텍스트 프롬프트로 변환."""
|
|
|
|
|
|
|
|
|
|
| 576 |
prompt = ""
|
| 577 |
for msg in conversation:
|
| 578 |
if msg["role"] == "user":
|
|
|
|
| 582 |
prompt += "Assistant: "
|
| 583 |
return prompt
|
| 584 |
|
| 585 |
+
# ------------------------- 메시지 스트리밍 함수 -------------------------
|
| 586 |
@spaces.GPU
|
| 587 |
def stream_chat(
|
| 588 |
message: str,
|
|
|
|
| 597 |
global model, current_file_context
|
| 598 |
|
| 599 |
try:
|
| 600 |
+
# 모델 미로드시 로딩
|
| 601 |
if model is None:
|
| 602 |
model = load_model()
|
| 603 |
|
| 604 |
+
print(f'[User input] message: {message}')
|
| 605 |
+
print(f'[History] {history}')
|
| 606 |
|
| 607 |
+
# (1) 파일 업로드 처리
|
| 608 |
file_context = ""
|
| 609 |
if uploaded_file and message == "파일을 분석하고 있습니다...":
|
| 610 |
current_file_context = None
|
|
|
|
| 619 |
current_file_context = file_context
|
| 620 |
message = "업로드된 파일을 분석해주세요."
|
| 621 |
except Exception as e:
|
| 622 |
+
print(f"[파일 분석 오류] {str(e)}")
|
| 623 |
file_context = f"\n\n❌ 파일 분석 중 오류가 발생했습니다: {str(e)}"
|
| 624 |
elif current_file_context:
|
| 625 |
file_context = current_file_context
|
| 626 |
|
| 627 |
+
# (2) TF-IDF 기반 관련 문서 탐색
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 628 |
wiki_context = ""
|
| 629 |
try:
|
| 630 |
relevant_contexts = find_relevant_context(message)
|
| 631 |
+
if relevant_contexts:
|
| 632 |
wiki_context = "\n\n관련 위키피디아 정보:\n"
|
| 633 |
for ctx in relevant_contexts:
|
| 634 |
wiki_context += (
|
|
|
|
| 637 |
f"유사도: {ctx['similarity']:.3f}\n\n"
|
| 638 |
)
|
| 639 |
except Exception as e:
|
| 640 |
+
print(f"[컨텍스트 검색 오류] {str(e)}")
|
| 641 |
+
|
| 642 |
+
# (3) 대화 이력 구성
|
| 643 |
+
max_history_length = 10
|
| 644 |
+
if len(history) > max_history_length:
|
| 645 |
+
history = history[-max_history_length:]
|
| 646 |
|
|
|
|
| 647 |
conversation = []
|
| 648 |
for prompt, answer in history:
|
| 649 |
conversation.extend([
|
|
|
|
| 651 |
{"role": "assistant", "content": answer}
|
| 652 |
])
|
| 653 |
|
| 654 |
+
# (4) 최종 메시지 결정
|
| 655 |
final_message = message
|
| 656 |
if file_context:
|
| 657 |
final_message = file_context + "\n현재 질문: " + message
|
|
|
|
| 659 |
final_message = wiki_context + "\n현재 질문: " + message
|
| 660 |
if file_context and wiki_context:
|
| 661 |
final_message = file_context + wiki_context + "\n현재 질문: " + message
|
| 662 |
+
|
| 663 |
conversation.append({"role": "user", "content": final_message})
|
| 664 |
|
| 665 |
+
# (5) 토큰화 및 프롬프트 구축
|
| 666 |
input_ids_str = build_prompt(conversation)
|
|
|
|
|
|
|
| 667 |
max_context = 8192
|
| 668 |
tokenized_input = tokenizer(input_ids_str, return_tensors="pt")
|
| 669 |
input_length = tokenized_input["input_ids"].shape[1]
|
| 670 |
+
|
| 671 |
+
# (6) 컨텍스트가 너무 길면 앞부분 토큰 자르기
|
| 672 |
if input_length > max_context - max_new_tokens:
|
| 673 |
+
print(f"[경고] 입력이 너무 깁니다: {input_length} 토큰 -> 잘라냄.")
|
|
|
|
| 674 |
min_generation = min(256, max_new_tokens)
|
| 675 |
new_desired_input_length = max_context - min_generation
|
|
|
|
|
|
|
| 676 |
tokens = tokenizer.encode(input_ids_str)
|
| 677 |
if len(tokens) > new_desired_input_length:
|
| 678 |
tokens = tokens[-new_desired_input_length:]
|
| 679 |
input_ids_str = tokenizer.decode(tokens)
|
|
|
|
|
|
|
| 680 |
tokenized_input = tokenizer(input_ids_str, return_tensors="pt")
|
| 681 |
input_length = tokenized_input["input_ids"].shape[1]
|
| 682 |
+
|
| 683 |
+
print(f"[토큰 길이] {input_length}")
|
|
|
|
|
|
|
| 684 |
inputs = tokenized_input.to("cuda")
|
| 685 |
+
|
| 686 |
+
# 남은 토큰 수로 max_new_tokens 조정
|
| 687 |
remaining = max_context - input_length
|
| 688 |
if remaining < max_new_tokens:
|
| 689 |
+
print(f"[max_new_tokens 조정] {max_new_tokens} -> {remaining}")
|
| 690 |
max_new_tokens = remaining
|
| 691 |
|
|
|
|
|
|
|
| 692 |
# 스트리머 설정
|
| 693 |
streamer = TextIteratorStreamer(
|
| 694 |
tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True
|
| 695 |
)
|
| 696 |
+
|
| 697 |
+
# (7) 생성 파라미터
|
| 698 |
generate_kwargs = dict(
|
| 699 |
**inputs,
|
| 700 |
streamer=streamer,
|
|
|
|
| 709 |
use_cache=True
|
| 710 |
)
|
| 711 |
|
|
|
|
| 712 |
clear_cuda_memory()
|
| 713 |
|
| 714 |
+
# (8) 별도 스레드에서 생성
|
| 715 |
thread = Thread(target=model.generate, kwargs=generate_kwargs)
|
| 716 |
thread.start()
|
| 717 |
|
| 718 |
+
# (9) 스트리밍 응답
|
| 719 |
buffer = ""
|
| 720 |
partial_message = ""
|
| 721 |
last_yield_time = time.time()
|
| 722 |
+
|
| 723 |
try:
|
| 724 |
for new_text in streamer:
|
| 725 |
+
buffer += new_text
|
| 726 |
+
partial_message += new_text
|
| 727 |
+
|
| 728 |
+
# 일정 시간 또는 버퍼 길이 기준으로 yield
|
| 729 |
+
current_time = time.time()
|
| 730 |
+
if (current_time - last_yield_time > 0.1) or (len(partial_message) > 20):
|
| 731 |
+
yield "", history + [[message, buffer]]
|
| 732 |
+
partial_message = ""
|
| 733 |
+
last_yield_time = current_time
|
| 734 |
+
|
| 735 |
+
# 마지막 완성된 응답
|
|
|
|
|
|
|
|
|
|
|
|
|
| 736 |
if buffer:
|
| 737 |
yield "", history + [[message, buffer]]
|
| 738 |
+
|
| 739 |
+
# 대화 내용 저장
|
| 740 |
chat_history.add_conversation(message, buffer)
|
| 741 |
+
|
| 742 |
except Exception as e:
|
| 743 |
+
print(f"[스트리밍 중 오류] {str(e)}")
|
| 744 |
+
if not buffer: # buffer가 비어있다면 오류메시지 대화창 표시
|
| 745 |
+
buffer = f"응답 생성 중 오류 발생: {str(e)}"
|
| 746 |
yield "", history + [[message, buffer]]
|
| 747 |
+
|
|
|
|
| 748 |
if thread.is_alive():
|
| 749 |
thread.join(timeout=5.0)
|
| 750 |
+
|
|
|
|
| 751 |
clear_cuda_memory()
|
| 752 |
|
| 753 |
except Exception as e:
|
| 754 |
import traceback
|
| 755 |
error_details = traceback.format_exc()
|
| 756 |
error_message = f"오류가 발생했습니다: {str(e)}\n{error_details}"
|
| 757 |
+
print(f"[Stream chat 오류] {error_message}")
|
| 758 |
clear_cuda_memory()
|
| 759 |
yield "", history + [[message, error_message]]
|
| 760 |
|
| 761 |
+
# ------------------------- Gradio UI 구성 -------------------------
|
| 762 |
def create_demo():
|
| 763 |
with gr.Blocks(css=CSS) as demo:
|
| 764 |
with gr.Column(elem_classes="markdown-style"):
|
|
|
|
| 809 |
scale=1
|
| 810 |
)
|
| 811 |
|
| 812 |
+
# 고급 설정
|
| 813 |
with gr.Accordion("🎮 Advanced Settings", open=False):
|
| 814 |
with gr.Row():
|
| 815 |
with gr.Column(scale=1):
|
|
|
|
| 835 |
label="Repetition Penalty 🔄"
|
| 836 |
)
|
| 837 |
|
| 838 |
+
# 예시
|
| 839 |
gr.Examples(
|
| 840 |
examples=[
|
| 841 |
["Please analyze this code and suggest improvements:\ndef fibonacci(n):\n if n <= 1: return n\n return fibonacci(n-1) + fibonacci(n-2)"],
|
|
|
|
| 846 |
inputs=msg
|
| 847 |
)
|
| 848 |
|
| 849 |
+
# 대화 내용 초기화
|
| 850 |
def clear_conversation():
|
| 851 |
global current_file_context
|
| 852 |
current_file_context = None
|
| 853 |
return [], None, "Start a new conversation..."
|
| 854 |
|
| 855 |
+
# 메시지 전송
|
| 856 |
msg.submit(
|
| 857 |
stream_chat,
|
| 858 |
inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty],
|
| 859 |
outputs=[msg, chatbot]
|
| 860 |
)
|
|
|
|
| 861 |
send.click(
|
| 862 |
stream_chat,
|
| 863 |
inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty],
|
| 864 |
outputs=[msg, chatbot]
|
| 865 |
)
|
| 866 |
|
| 867 |
+
# 파일 업로드 이벤트
|
| 868 |
file_upload.change(
|
| 869 |
fn=lambda: ("처리 중...", [["시스템", "파일을 분석 중입니다. 잠시만 기다려주세요..."]]),
|
| 870 |
outputs=[msg, chatbot],
|
|
|
|
| 880 |
queue=True
|
| 881 |
)
|
| 882 |
|
| 883 |
+
# Clear 버튼
|
| 884 |
clear.click(
|
| 885 |
fn=clear_conversation,
|
| 886 |
outputs=[chatbot, file_upload, msg],
|
|
|
|
| 889 |
|
| 890 |
return demo
|
| 891 |
|
| 892 |
+
# 메인 실행
|
| 893 |
if __name__ == "__main__":
|
| 894 |
demo = create_demo()
|
| 895 |
+
demo.launch()
|