import gradio as gr
import spaces
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    GenerationConfig,
    TextIteratorStreamer,
)
from threading import Thread
import re
import json
from datetime import datetime
from zoneinfo import ZoneInfo
import math
import os

# ═══════════════════════════════════════════════════════════
# 🔧 Model loading
# ═══════════════════════════════════════════════════════════
MODEL_ID = "zai-org/GLM-4.7-Flash"

print(f"[Init] Loading tokenizer from {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# The model weights are loaded lazily (see get_model) so startup stays fast
# and GPU memory is only claimed on the first actual request.
model = None


def get_model():
    """Lazily load and cache the causal LM; return the shared instance."""
    global model
    if model is None:
        print("[Model] Loading model with bfloat16...")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )
        print(f"[Model] Model loaded on {model.device}")
    return model


# ═══════════════════════════════════════════════════════════
# 📄 File processing helpers
# ═══════════════════════════════════════════════════════════
def extract_text_from_pdf(file_path: str) -> str:
    """Extract text from a PDF file.

    Tries PyMuPDF (fitz) first and falls back to pypdf when PyMuPDF is not
    installed. Errors are returned as placeholder strings instead of raising,
    so the UI can display them.
    """
    try:
        import fitz  # PyMuPDF

        doc = fitz.open(file_path)
        text_parts = []
        for page_num, page in enumerate(doc, 1):
            text = page.get_text()
            if text.strip():
                text_parts.append(f"[페이지 {page_num}]\n{text}")
        doc.close()
        return "\n\n".join(text_parts) if text_parts else "[PDF에서 텍스트를 추출할 수 없습니다]"
    except ImportError:
        # PyMuPDF unavailable — fall back to the pure-Python pypdf reader.
        try:
            from pypdf import PdfReader

            reader = PdfReader(file_path)
            text_parts = []
            for page_num, page in enumerate(reader.pages, 1):
                text = page.extract_text()
                if text and text.strip():
                    text_parts.append(f"[페이지 {page_num}]\n{text}")
            return "\n\n".join(text_parts) if text_parts else "[PDF에서 텍스트를 추출할 수 없습니다]"
        except Exception as e:
            return f"[PDF 읽기 오류: {str(e)}]"
    except Exception as e:
        return f"[PDF 읽기 오류: {str(e)}]"


def extract_text_from_docx(file_path: str) -> str:
    """Extract paragraph and table text from a DOCX file.

    Table rows are rendered as pipe-separated lines under a numbered
    "[표 N]" header. Errors are returned as placeholder strings.
    """
    try:
        from docx import Document

        doc = Document(file_path)
        text_parts = []
        for para in doc.paragraphs:
            if para.text.strip():
                text_parts.append(para.text)
        for table_idx, table in enumerate(doc.tables, 1):
            table_text = [f"\n[표 {table_idx}]"]
            for row in table.rows:
                row_text = " | ".join(cell.text.strip() for cell in row.cells)
                if row_text.strip():
                    table_text.append(row_text)
            # Only emit the table section if it has at least one non-empty row.
            if len(table_text) > 1:
                text_parts.append("\n".join(table_text))
        return "\n\n".join(text_parts) if text_parts else "[DOCX에서 텍스트를 추출할 수 없습니다]"
    except Exception as e:
        return f"[DOCX 읽기 오류: {str(e)}]"


def extract_text_from_txt(file_path: str) -> str:
    """Read a text file, trying several encodings common for Korean text."""
    try:
        encodings = ['utf-8', 'cp949', 'euc-kr', 'latin-1']
        for encoding in encodings:
            try:
                with open(file_path, 'r', encoding=encoding) as f:
                    return f.read()
            except UnicodeDecodeError:
                continue
        return "[텍스트 파일 인코딩을 인식할 수 없습니다]"
    except Exception as e:
        return f"[TXT 읽기 오류: {str(e)}]"


def process_uploaded_file(file) -> tuple:
    """Dispatch an uploaded file to the right extractor.

    Returns (file_name, content). Content is truncated to 50k characters so
    the document never overflows the model context. A (``""``, ``""``) pair
    is returned when no file was given.
    """
    if file is None:
        return "", ""
    file_path = file.name if hasattr(file, 'name') else str(file)
    file_name = os.path.basename(file_path)
    file_ext = os.path.splitext(file_name)[1].lower()

    if file_ext == '.pdf':
        content = extract_text_from_pdf(file_path)
    elif file_ext == '.docx':
        content = extract_text_from_docx(file_path)
    elif file_ext in ['.txt', '.md', '.py', '.js', '.html', '.css', '.json', '.xml', '.csv']:
        content = extract_text_from_txt(file_path)
    else:
        content = f"[지원하지 않는 파일 형식: {file_ext}]"

    # Cap document size so the system prompt stays within the context window.
    max_chars = 50000
    if len(content) > max_chars:
        content = content[:max_chars] + f"\n\n... [텍스트가 {max_chars}자로 잘렸습니다]"

    return file_name, content


# ═══════════════════════════════════════════════════════════
# 🛠️ Tool definitions
# ═══════════════════════════════════════════════════════════
def execute_tool(tool_name: str, arguments: dict) -> str:
    """Execute one of the built-in tools and return a human-readable result.

    Supported tools: calculator, get_current_time, unit_converter,
    code_executor. Any failure is caught and returned as an error string.
    """
    try:
        if tool_name == "calculator":
            expr = arguments.get("expression", "")
            allowed_names = {
                "abs": abs, "round": round, "min": min, "max": max,
                "sum": sum, "pow": pow, "sqrt": math.sqrt,
                "sin": math.sin, "cos": math.cos, "tan": math.tan,
                "log": math.log, "log10": math.log10, "exp": math.exp,
                "pi": math.pi, "e": math.e,
                "floor": math.floor, "ceil": math.ceil,
            }
            # Strip characters outside a math-expression whitelist before eval.
            expr = re.sub(r'[^0-9+\-*/().a-zA-Z_ ]', '', expr)
            # NOTE(security): eval() on model-generated text. Builtins are
            # emptied and the charset is filtered, but eval is never fully
            # safe — consider ast.literal_eval or a dedicated parser.
            result = eval(expr, {"__builtins__": {}}, allowed_names)
            return f"계산 결과: {expr} = {result}"

        elif tool_name == "get_current_time":
            tz = arguments.get("timezone", "UTC")
            # Honor the requested IANA timezone; fall back to local time on
            # an unknown zone name so the tool never hard-fails.
            try:
                now = datetime.now(ZoneInfo(tz))
            except Exception:
                now = datetime.now()
            return f"현재 시간 ({tz}): {now.strftime('%Y-%m-%d %H:%M:%S')}"

        elif tool_name == "unit_converter":
            value = arguments.get("value", 0)
            from_unit = arguments.get("from_unit", "").lower()
            to_unit = arguments.get("to_unit", "").lower()
            # Direct (from, to) -> conversion table; no transitive chaining.
            conversions = {
                ("km", "m"): lambda x: x * 1000,
                ("m", "km"): lambda x: x / 1000,
                ("kg", "g"): lambda x: x * 1000,
                ("g", "kg"): lambda x: x / 1000,
                ("c", "f"): lambda x: x * 9/5 + 32,
                ("f", "c"): lambda x: (x - 32) * 5/9,
                ("km", "mile"): lambda x: x * 0.621371,
                ("mile", "km"): lambda x: x * 1.60934,
                ("kg", "lb"): lambda x: x * 2.20462,
                ("lb", "kg"): lambda x: x * 0.453592,
            }
            key = (from_unit, to_unit)
            if key in conversions:
                result = conversions[key](value)
                return f"변환 결과: {value} {from_unit} = {result:.4f} {to_unit}"
            else:
                return f"지원하지 않는 단위 변환: {from_unit} -> {to_unit}"

        elif tool_name == "code_executor":
            code = arguments.get("code", "")
            local_vars = {}
            safe_builtins = {"print": print, "range": range, "len": len,
                             "str": str, "int": int, "float": float,
                             "list": list, "dict": dict}
            # NOTE(security): exec() on model-generated code. Restricting
            # __builtins__ is NOT a real sandbox (attribute traversal can
            # escape it); run this only in an isolated/ephemeral environment.
            exec(code, {"__builtins__": safe_builtins}, local_vars)
            # Convention: the snippet stores its answer in a `result` variable.
            if "result" in local_vars:
                return f"실행 결과: {local_vars['result']}"
            return "코드 실행 완료"

        else:
            return f"알 수 없는 도구: {tool_name}"
    except Exception as e:
        return f"도구 실행 오류: {str(e)}"


def parse_tool_calls(response: str) -> list:
    """Parse tool-call JSON objects out of a model response.

    Tries, in order: GLM tool-call tags, fenced ```json blocks containing a
    "name" key, and a bare inline {"name": ..., "arguments": ...} form.
    Returns a (possibly empty) list of {"name", "arguments"} dicts.
    """
    tool_calls = []
    patterns = [
        r'<\|tool_call\|>(\{.*?\})<\|/tool_call\|>',
        r'```json\s*(\{[^`]*"name"[^`]*\})\s*```',
        r'\{"name":\s*"(\w+)",\s*"arguments":\s*(\{[^}]+\})\}',
    ]
    for pattern in patterns:
        matches = re.findall(pattern, response, re.DOTALL)
        for match in matches:
            try:
                if isinstance(match, tuple):
                    # Third pattern captures (name, arguments-json) separately.
                    tool_call = {"name": match[0], "arguments": json.loads(match[1])}
                else:
                    tool_call = json.loads(match)
                tool_calls.append(tool_call)
            except (json.JSONDecodeError, TypeError, ValueError):
                # Malformed candidate — skip it and keep scanning.
                continue
    return tool_calls


# ═══════════════════════════════════════════════════════════
# 💬 Streaming chat (Gradio 6.0 messages format)
# ═══════════════════════════════════════════════════════════
# Module-level context for the single uploaded document (one file at a time).
file_context = {"name": "", "content": ""}


@spaces.GPU(duration=120)
def chat_streaming(
    message: str,
    history: list,
    system_prompt: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    enable_thinking: bool,
    enable_tools: bool,
):
    """Generate a streaming chat response in Gradio messages format.

    Yields the full updated history (list of {"role", "content"} dicts) on
    every new token. After generation, optionally parses and executes tool
    calls found in the response and appends their results.
    """
    global file_context

    if not message.strip():
        yield history
        return

    model = get_model()

    # ── Build the system prompt ──
    sys_content = system_prompt if system_prompt.strip() else "You are a helpful AI assistant."

    if file_context["content"]:
        sys_content += f"\n\n[업로드된 파일: {file_context['name']}]\n파일 내용:\n---\n{file_context['content']}\n---"

    if enable_tools:
        tool_desc = """
You have access to these tools:
1. calculator: Math calculations - {"name": "calculator", "arguments": {"expression": "..."}}
2. get_current_time: Current time - {"name": "get_current_time", "arguments": {}}
3. unit_converter: Unit conversion - {"name": "unit_converter", "arguments": {"value": N, "from_unit": "...", "to_unit": "..."}}
4. code_executor: Run Python - {"name": "code_executor", "arguments": {"code": "..."}}
"""
        sys_content += f"\n\n{tool_desc}"

    # ── Build the message list for the model ──
    messages = [{"role": "system", "content": sys_content}]

    # Convert history (accepts both messages dicts and legacy [user, bot] pairs).
    for h in history:
        if isinstance(h, dict):
            messages.append({"role": h["role"], "content": h["content"]})
        elif isinstance(h, (list, tuple)) and len(h) == 2:
            if h[0]:
                messages.append({"role": "user", "content": h[0]})
            if h[1]:
                messages.append({"role": "assistant", "content": h[1]})

    # Current user turn; optionally prefix a chain-of-thought nudge.
    user_content = message
    if enable_thinking:
        user_content = f"\nLet me think step by step.\n\n\n{message}"
    messages.append({"role": "user", "content": user_content})

    # ── Tokenize ──
    try:
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)
    except Exception as e:
        new_history = history + [
            {"role": "user", "content": message},
            {"role": "assistant", "content": f"토크나이즈 오류: {str(e)}"}
        ]
        yield new_history
        return

    # ── Streaming generation on a background thread ──
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    gen_config = GenerationConfig(
        max_new_tokens=max_tokens,
        # Greedy decoding ignores temperature; keep it strictly positive to
        # satisfy validation when sampling is off.
        temperature=temperature if temperature > 0 else 0.01,
        top_p=top_p,
        do_sample=temperature > 0,
        pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
    )

    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "generation_config": gen_config,
    }

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Append the new user/assistant turn and stream tokens into the latter.
    new_history = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": ""}
    ]

    partial_response = ""
    for new_token in streamer:
        partial_response += new_token
        new_history[-1]["content"] = partial_response
        yield new_history

    thread.join()

    # ── Post-generation tool calling ──
    if enable_tools:
        tool_calls = parse_tool_calls(partial_response)
        if tool_calls:
            tool_results = []
            for tc in tool_calls:
                result = execute_tool(tc.get("name", ""), tc.get("arguments", {}))
                tool_results.append(result)
            if tool_results:
                final_response = partial_response + "\n\n📌 **도구 실행 결과:**\n" + "\n".join(tool_results)
                new_history[-1]["content"] = final_response
                yield new_history


def handle_file_upload(file):
    """Handle a file upload: extract text, cache it, and return a status message."""
    global file_context

    if file is None:
        file_context = {"name": "", "content": ""}
        return "📂 파일이 제거되었습니다."

    file_name, content = process_uploaded_file(file)

    # Extractors signal failure with bracketed placeholder strings; do not
    # cache those as document content (covers read errors AND unsupported
    # file formats, which previously slipped through).
    if content.startswith("[") and ("오류" in content or "지원하지 않는" in content):
        file_context = {"name": "", "content": ""}
        return f"❌ {content}"

    file_context = {"name": file_name, "content": content}
    preview = content[:500] + "..." if len(content) > 500 else content
    char_count = len(content)

    return f"✅ **파일 로드 완료: {file_name}**\n- 문자 수: {char_count:,}자\n\n미리보기:\n```\n{preview}\n```"


def clear_file():
    """Reset the cached file context and clear the upload widget."""
    global file_context
    file_context = {"name": "", "content": ""}
    return None, "📂 파일이 제거되었습니다."


def clear_chat():
    """Reset the chat history."""
    return []


# ═══════════════════════════════════════════════════════════
# 🎨 Gradio UI (6.0 compatible — messages format)
# ═══════════════════════════════════════════════════════════
with gr.Blocks(title="GLM-4.7-Flash Chatbot") as demo:
    gr.Markdown("""
    # 🤖 GLM-4.7-Flash Chatbot
    **30B-A3B MoE 모델 기반 스트리밍 챗봇** | 문서 분석 | Tool Calling

    📄 PDF | 📝 DOCX | 📃 TXT | 🧮 계산기 | 🕐 시간조회 | 📐 단위변환 | 🐍 코드실행
    """)

    with gr.Row():
        with gr.Column(scale=3):
            # chat_streaming yields {"role", "content"} dicts, so the
            # messages type must be explicit (tuples is the older default).
            chatbot = gr.Chatbot(
                label="대화",
                height=500,
                type="messages",
            )

            with gr.Row():
                message = gr.Textbox(
                    label="메시지 입력",
                    placeholder="메시지를 입력하세요...",
                    lines=3,
                    scale=4,
                )
                submit_btn = gr.Button("전송 📤", variant="primary", scale=1)

            with gr.Row():
                clear_btn = gr.Button("대화 초기화 🗑️")
                stop_btn = gr.Button("생성 중지 ⏹️")

            with gr.Accordion("📁 문서 업로드 (PDF / DOCX / TXT)", open=True):
                file_upload = gr.File(
                    label="파일 선택",
                    file_types=[".pdf", ".docx", ".txt", ".md", ".py", ".js",
                                ".html", ".css", ".json", ".xml", ".csv"],
                    file_count="single",
                )
                file_status = gr.Markdown("📂 파일을 업로드하면 내용을 분석할 수 있습니다.")
                clear_file_btn = gr.Button("📂 파일 제거", size="sm")

        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ 설정")
            system_prompt = gr.Textbox(
                label="시스템 프롬프트",
                value="You are a helpful AI assistant. Answer in the same language as the user.",
                lines=3,
            )
            max_tokens = gr.Slider(64, 4096, value=1024, step=64, label="최대 토큰 수")
            temperature = gr.Slider(0, 2, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
            enable_thinking = gr.Checkbox(label="🧠 Thinking 모드", value=False)
            enable_tools = gr.Checkbox(label="🛠️ Tool Calling", value=True)

            gr.Markdown("### 📝 예시")
            gr.Examples(
                examples=[
                    ["안녕하세요!"],
                    ["업로드한 문서를 요약해줘"],
                    ["123 * 456을 계산해줘"],
                    ["현재 시간은?"],
                    ["100km는 몇 마일?"],
                ],
                inputs=message,
            )

    # Events — in Gradio 6.0 only the chatbot is the streaming output.
    submit_event = submit_btn.click(
        fn=chat_streaming,
        inputs=[message, chatbot, system_prompt, max_tokens, temperature,
                top_p, enable_thinking, enable_tools],
        outputs=[chatbot],
    ).then(
        fn=lambda: "",
        outputs=[message],
    )

    message.submit(
        fn=chat_streaming,
        inputs=[message, chatbot, system_prompt, max_tokens, temperature,
                top_p, enable_thinking, enable_tools],
        outputs=[chatbot],
    ).then(
        fn=lambda: "",
        outputs=[message],
    )

    clear_btn.click(fn=clear_chat, outputs=[chatbot])
    stop_btn.click(fn=None, cancels=[submit_event])
    file_upload.change(fn=handle_file_upload, inputs=[file_upload], outputs=[file_status])
    clear_file_btn.click(fn=clear_file, outputs=[file_upload, file_status])


if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)