# Hugging Face Spaces app (scraped Space status header: "Sleeping")
import io
import os
import traceback
from typing import Any, Optional

import fitz  # PyMuPDF
import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
# --- 1. Global state and environment configuration ---

# Model handles start empty; the loading section below fills them in.
tokenizer = None
model = None
MODEL_LOADED = False

# Pull variables from a .env file when python-dotenv is available
# (mainly for local development); otherwise rely on the OS environment.
try:
    from dotenv import load_dotenv
except ImportError:
    print("โ ๏ธ python-dotenv๊ฐ ์ค์น๋์ง ์์, ์์คํ ํ๊ฒฝ ๋ณ์ ์ฌ์ฉ")
else:
    load_dotenv()
    print("โ .env ํ์ผ ๋ก๋๋จ")

# Hugging Face access token and model repo to load.
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_NAME = os.getenv("MODEL_NAME", "gbrabbit/lily-math-model")
print(f"๐ ๋ชจ๋ธ: {MODEL_NAME}")
_token_status = 'โ ์ค์ ๋จ' if HF_TOKEN else 'โ ์ค์ ๋์ง ์์'
print(f"๐ HF ํ ํฐ: {_token_status}")
# --- 2. Core logic: load the model and tokenizer ---
# NOTE: this runs at import time; any failure leaves MODEL_LOADED False
# and the UI degrades to an error message instead of crashing.
try:
    print("๐ง ๋ชจ๋ธ ๋ฐ ํ ํฌ๋์ด์ ๋ก๋ฉ ์์...")
    # Import the custom model class (project-local `modeling` module).
    from modeling import KananaVForConditionalGeneration
    if HF_TOKEN:
        # Token present: load the gated/custom multimodal model.
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            token=HF_TOKEN,
            trust_remote_code=True
        )
        model = KananaVForConditionalGeneration.from_pretrained(
            MODEL_NAME,
            token=HF_TOKEN,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map="auto"  # auto-assign to GPU (required in server environments)
        )
        MODEL_LOADED = True
        print("โ ์ปค์คํ ๋ชจ๋ธ ๋ก๋ฉ ์๋ฃ!")
    else:
        # No token: fall back to a public text-only model so the app still runs.
        print("โ ๏ธ HF ํ ํฐ์ด ์์ด ๊ณต๊ฐ ๋ชจ๋ธ(DialoGPT)๋ก ๋์ฒดํฉ๋๋ค.")
        MODEL_NAME = "microsoft/DialoGPT-medium"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, device_map="auto")
        MODEL_LOADED = True
except Exception as e:
    # Leave MODEL_LOADED False; generate_response() reports the failure to users.
    print(f"โ ๋ชจ๋ธ ๋ก๋ฉ ์คํจ: {e}")
    traceback.print_exc()
    MODEL_LOADED = False
# --- 3. File handling utilities ---
def extract_text_from_pdf(pdf_file):
    """Extract the concatenated text of every page in a PDF.

    Accepts either a file-like object exposing ``.read()`` or a
    filesystem path (``str``/``os.PathLike``) — gradio's File component
    returns different shapes depending on version, so both are handled.

    Returns the extracted text, or a Korean error-message string if the
    PDF cannot be read (best-effort, mirrors the caller's expectations
    instead of raising).
    """
    try:
        if hasattr(pdf_file, "read"):
            # In-memory / file-like input: feed raw bytes to PyMuPDF.
            doc = fitz.open(stream=pdf_file.read(), filetype="pdf")
        else:
            # Path-like input (newer gradio versions hand back file paths).
            doc = fitz.open(pdf_file)
        try:
            return "".join(page.get_text() for page in doc)
        finally:
            # Close even if get_text() raises, so the handle never leaks.
            doc.close()
    except Exception as e:
        print(f"PDF ์ฒ๋ฆฌ ์ค๋ฅ: {e}")
        return f"PDF ํ์ผ์ ์ฝ๋ ์ค ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค: {e}"
def process_uploaded_file(file):
    """Split an uploaded file into text and an image object.

    Parameters:
        file: ``None``, a filesystem path string, or an object exposing
            ``.name`` (gradio File). Gradio returns different shapes
            depending on version, so both are accepted.

    Returns:
        (text, image) where ``text`` is extracted/placeholder text and
        ``image`` is a PIL RGB image or ``None``.
    """
    if file is None:
        return "", None  # nothing uploaded: no text, no image

    # Newer gradio versions hand back a plain path string; older ones an
    # object with a .name attribute pointing at a temp file.
    file_path = file if isinstance(file, str) else file.name
    file_extension = os.path.splitext(file_path)[1].lower()

    if file_extension == '.pdf':
        # PDFs contribute text only, no image.
        return extract_text_from_pdf(file), None
    if file_extension in ('.png', '.jpg', '.jpeg'):
        # Keep the image object itself as multimodal input (no OCR).
        image = Image.open(file_path).convert('RGB')
        return "์ ๋ก๋๋ ์ด๋ฏธ์ง๊ฐ ์์ต๋๋ค.", image
    return f"์ง์ํ์ง ์๋ ํ์ผ ํ์: {file_extension}", None
# --- 4. Core logic: unified response generation ---
def generate_response(prompt_template: str, message: str, file: Optional[Any] = None):
    """Generate a model reply for text plus an optional file attachment.

    Parameters:
        prompt_template: format string containing a ``{message}`` placeholder.
        message: the user's raw text input.
        file: optional gradio upload (PDF or image), or ``None``.

    Returns:
        The decoded model answer, or a Korean error string on failure.

    Fix: the original annotation was a bare ``Optional`` that was never
    imported — evaluated at def time, it raised ``NameError`` on module
    import. ``Optional[Any]`` (imported from ``typing``) is equivalent
    in intent and actually resolves.
    """
    if not MODEL_LOADED:
        return "โ ๋ชจ๋ธ์ด ๋ก๋๋์ง ์์์ต๋๋ค. ๊ด๋ฆฌ์์๊ฒ ๋ฌธ์ํ์ธ์."
    try:
        # 1. Split the attachment into text and/or a PIL image.
        file_text, pil_image = process_uploaded_file(file)

        # 2. Build the full prompt (attachment text appended after the message).
        full_message = message
        if file_text:
            full_message += f"\n\n[์ฒจ๋ถ ํ์ผ ๋ด์ฉ]\n{file_text}"
        full_prompt = prompt_template.format(message=full_message)

        # 3. Tokenize the text input onto the model's device.
        inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

        # 4. Generation parameters.
        generation_args = {
            "max_new_tokens": 512,
            "temperature": 0.7,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id
        }

        # 5. Multimodal path: attach preprocessed pixels when an image exists.
        #    (`is not None` — PIL images should not be truth-tested.)
        if pil_image is not None:
            print("๐ผ๏ธ ์ด๋ฏธ์ง ํฌํจ, ๋ฉํฐ๋ชจ๋ฌ ๋ชจ๋๋ก ์์ฑ")
            # NOTE(review): preprocessing assumed from KananaV's vision
            # tower API — confirm against the model's requirements.
            pixel_values = model.vision_model.image_processor(pil_image, return_tensors='pt')['pixel_values']
            generation_args["pixel_values"] = pixel_values.to(model.device, dtype=torch.float16)
        else:
            print("๐ ํ ์คํธ๋ง์ผ๋ก ์์ฑ")

        # 6. A single generate() call covers both text-only and multimodal modes.
        with torch.no_grad():
            outputs = model.generate(**inputs, **generation_args)

        # 7. Decode only the newly generated tokens (strip the prompt echo).
        input_length = inputs["input_ids"].shape[1]
        response_ids = outputs[0][input_length:]
        return tokenizer.decode(response_ids, skip_special_tokens=True).strip()
    except Exception as e:
        print(f"โ ์๋ต ์์ฑ ์ค ์ค๋ฅ ๋ฐ์: {e}")
        traceback.print_exc()
        return f"์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค: {e}"
# --- 5. Gradio UI layout and event wiring ---
with gr.Blocks(title="Lily Math RAG System", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ๐งฎ Lily Math RAG System")
    gr.Markdown("์ํ ๋ฌธ์ ํด๊ฒฐ ๋ฐ ๋ฉํฐ๋ชจ๋ฌ ๋ํ๋ฅผ ์ํ AI ์์คํ ์ ๋๋ค.")
    with gr.Tabs():
        with gr.Tab("๐ฌ ์ฑํ "):
            # ChatML-style wrapper the model expects around the user message.
            chat_prompt = "<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
            chatbot = gr.Chatbot(height=500, label="๋ํ์ฐฝ", type="messages")
            with gr.Row():
                with gr.Column(scale=4):
                    msg = gr.Textbox(label="๋ฉ์์ง", placeholder="์ด๋ฏธ์ง๋ PDF๋ฅผ ์ฒจ๋ถํ๊ณ ์ง๋ฌธํด๋ณด์ธ์!", lines=3, show_label=False)
                with gr.Column(scale=1, min_width=150):
                    file_input = gr.File(label="ํ์ผ ์ ๋ก๋", file_types=[".pdf", ".png", ".jpg", ".jpeg"])

            def respond(message, chat_history, file):
                # Generate the reply, then append both turns in the
                # "messages" schema that gr.Chatbot(type="messages") expects.
                bot_message = generate_response(chat_prompt, message, file)
                chat_history.append({"role": "user", "content": message})
                chat_history.append({"role": "assistant", "content": bot_message})
                # Clear the textbox and push the updated history to the chat.
                return "", chat_history

            msg.submit(respond, [msg, chatbot, file_input], [msg, chatbot])
        with gr.Tab("โ๏ธ ์์คํ ์ ๋ณด"):
            # Status tab: reflects the values computed at import time.
            gr.Markdown(f"**๋ชจ๋ธ**: `{MODEL_NAME}`")
            gr.Markdown(f"**๋ชจ๋ธ ์ํ**: `{'โ ๋ก๋๋จ' if MODEL_LOADED else 'โ ๋ก๋ ์คํจ'}`")
if __name__ == "__main__":
    # share=True creates a public link reachable from outside the host.
    demo.launch(share=True)