# lily-math-rag / app.py
# gbrabbit — auto commit at 07-2025-08 4:43:48, revision b9ecb65 (7.29 kB)
# (Hugging Face file-viewer chrome — "raw / history / blame" — kept here as a comment header.)
import gradio as gr
import os
import traceback
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import fitz # PyMuPDF
from PIL import Image
import io
# --- 1. Global state & environment configuration ---
# Model handles stay None/False until the loading section below populates them.
tokenizer = None
model = None
MODEL_LOADED = False

# Pull variables from a local .env file when python-dotenv is installed
# (mainly for local development; hosted environments inject env vars directly).
try:
    from dotenv import load_dotenv
    load_dotenv()
    print("โœ… .env ํŒŒ์ผ ๋กœ๋“œ๋จ")
except ImportError:
    print("โš ๏ธ python-dotenv๊ฐ€ ์„ค์น˜๋˜์ง€ ์•Š์Œ, ์‹œ์Šคํ…œ ํ™˜๊ฒฝ ๋ณ€์ˆ˜ ์‚ฌ์šฉ")

# HF access token (optional) and the model repo id to load.
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_NAME = os.getenv("MODEL_NAME", "gbrabbit/lily-math-model")

print(f"๐Ÿ” ๋ชจ๋ธ: {MODEL_NAME}")
print(f"๐Ÿ” HF ํ† ํฐ: {'โœ… ์„ค์ •๋จ' if HF_TOKEN else 'โŒ ์„ค์ •๋˜์ง€ ์•Š์Œ'}")
# --- 2. Core logic: load the model and tokenizer ---
# Populates the module-level `tokenizer`, `model` and `MODEL_LOADED` globals.
# Any failure (missing custom module, download/auth error, ...) is caught so
# the UI can report the problem instead of crashing at import time.
try:
    print("๐Ÿ”ง ๋ชจ๋ธ ๋ฐ ํ† ํฌ๋‚˜์ด์ € ๋กœ๋”ฉ ์‹œ์ž‘...")

    # Custom multimodal model class shipped alongside this app.
    # NOTE: imported unconditionally, so a missing `modeling` module fails
    # both branches below — same as the original flow.
    from modeling import KananaVForConditionalGeneration

    if not HF_TOKEN:
        # No token: fall back to a public, text-only model.
        print("โš ๏ธ HF ํ† ํฐ์ด ์—†์–ด ๊ณต๊ฐœ ๋ชจ๋ธ(DialoGPT)๋กœ ๋Œ€์ฒดํ•ฉ๋‹ˆ๋‹ค.")
        MODEL_NAME = "microsoft/DialoGPT-medium"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16,
            device_map="auto",
        )
    else:
        # Token present: load the gated custom multimodal model.
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            token=HF_TOKEN,
            trust_remote_code=True,
        )
        model = KananaVForConditionalGeneration.from_pretrained(
            MODEL_NAME,
            token=HF_TOKEN,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map="auto",  # automatic GPU placement (required on the server)
        )
        print("โœ… ์ปค์Šคํ…€ ๋ชจ๋ธ ๋กœ๋”ฉ ์™„๋ฃŒ!")
    MODEL_LOADED = True
except Exception as e:
    print(f"โŒ ๋ชจ๋ธ ๋กœ๋”ฉ ์‹คํŒจ: {e}")
    traceback.print_exc()
    MODEL_LOADED = False
# --- 3. ํŒŒ์ผ ์ฒ˜๋ฆฌ ์œ ํ‹ธ๋ฆฌํ‹ฐ ---
def extract_text_from_pdf(pdf_file):
    """Extract the plain text of every page of an uploaded PDF.

    Parameters:
        pdf_file: file-like object whose ``read()`` returns the raw PDF
            bytes (e.g. a Gradio upload handle).

    Returns:
        The concatenated text of all pages, or a Korean error-message
        string if the PDF could not be parsed (callers display the return
        value as text either way).
    """
    try:
        doc = fitz.open(stream=pdf_file.read(), filetype="pdf")
        try:
            return "".join(page.get_text() for page in doc)
        finally:
            # Fix: the original only closed on full success, leaking the
            # document if any page's get_text() raised.
            doc.close()
    except Exception as e:
        print(f"PDF ์ฒ˜๋ฆฌ ์˜ค๋ฅ˜: {e}")
        return f"PDF ํŒŒ์ผ์„ ์ฝ๋Š” ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {e}"
def process_uploaded_file(file):
    """Split an uploaded file into a (text content, PIL image) pair.

    PDFs yield extracted text and no image; PNG/JPEG yield a fixed notice
    string plus the RGB image object (fed to the multimodal model instead
    of OCR). ``None`` or unsupported extensions yield text only.
    """
    if file is None:
        return "", None  # nothing attached

    ext = os.path.splitext(file.name)[1].lower()

    if ext == '.pdf':
        # PDFs contribute text only.
        return extract_text_from_pdf(file), None
    if ext in ('.png', '.jpg', '.jpeg'):
        # Hand the image itself to the model rather than running OCR.
        rgb = Image.open(file).convert('RGB')
        return "์—…๋กœ๋“œ๋œ ์ด๋ฏธ์ง€๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค.", rgb
    return f"์ง€์›ํ•˜์ง€ ์•Š๋Š” ํŒŒ์ผ ํ˜•์‹: {ext}", None
# --- 4. ํ•ต์‹ฌ ๋กœ์ง: ํ†ตํ•ฉ ์‘๋‹ต ์ƒ์„ฑ ํ•จ์ˆ˜ ---
def generate_response(prompt_template: str, message: str, file=None):
    """Generate a model reply for a text message plus optional attachment.

    Parameters:
        prompt_template: format string containing a ``{message}``
            placeholder (chat-template markup around the user turn).
        message: the user's text input.
        file: optional Gradio upload handle (PDF or image); ``None`` when
            nothing was attached.
            BUGFIX: the original annotated this as bare ``Optional``
            without importing it from ``typing`` — annotations are
            evaluated at ``def`` time, so the module raised NameError on
            import. The broken annotation is dropped; default unchanged.

    Returns:
        The decoded assistant text, or a Korean error-message string.
    """
    if not MODEL_LOADED:
        return "โŒ ๋ชจ๋ธ์ด ๋กœ๋“œ๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. ๊ด€๋ฆฌ์ž์—๊ฒŒ ๋ฌธ์˜ํ•˜์„ธ์š”."
    try:
        # 1. Split the attachment into text and/or an image.
        file_text, pil_image = process_uploaded_file(file)

        # 2. Build the full prompt (attachment text appended to the message).
        full_message = message
        if file_text:
            full_message += f"\n\n[์ฒจ๋ถ€ ํŒŒ์ผ ๋‚ด์šฉ]\n{file_text}"
        full_prompt = prompt_template.format(message=full_message)

        # 3. Tokenize and move tensors onto the model's device.
        inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

        # 4. Generation parameters (sampling; pad with EOS for open-ended chat).
        generation_args = {
            "max_new_tokens": 512,
            "temperature": 0.7,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
        }

        # 5. Attach preprocessed pixels for multimodal generation when present.
        #    (`is not None` instead of truthiness — clearer for image objects.)
        if pil_image is not None:
            print("๐Ÿ–ผ๏ธ ์ด๋ฏธ์ง€ ํฌํ•จ, ๋ฉ€ํ‹ฐ๋ชจ๋‹ฌ ๋ชจ๋“œ๋กœ ์ƒ์„ฑ")
            # Preprocessing shape follows the KananaV model's requirements;
            # may need adjustment if the model's processor API changes.
            pixel_values = model.vision_model.image_processor(
                pil_image, return_tensors='pt'
            )['pixel_values']
            generation_args["pixel_values"] = pixel_values.to(
                model.device, dtype=torch.float16
            )
        else:
            print("๐Ÿ“„ ํ…์ŠคํŠธ๋งŒ์œผ๋กœ ์ƒ์„ฑ")

        # 6. One generate() call covers both the text-only and multimodal paths.
        with torch.no_grad():
            outputs = model.generate(**inputs, **generation_args)

        # 7. Decode only the newly generated tokens, stripping the prompt echo.
        input_length = inputs["input_ids"].shape[1]
        response_ids = outputs[0][input_length:]
        return tokenizer.decode(response_ids, skip_special_tokens=True).strip()
    except Exception as e:
        print(f"โŒ ์‘๋‹ต ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
        traceback.print_exc()
        return f"์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {e}"
# --- 5. Gradio UI ๋ฐ ์‹คํ–‰ ---
# --- 5. Gradio UI definition ---
# Declarative layout: a chat tab wired to generate_response(), plus a
# read-only system-info tab.  `demo` is launched by the __main__ guard below.
with gr.Blocks(title="Lily Math RAG System", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ๐Ÿงฎ Lily Math RAG System")
    gr.Markdown("์ˆ˜ํ•™ ๋ฌธ์ œ ํ•ด๊ฒฐ ๋ฐ ๋ฉ€ํ‹ฐ๋ชจ๋‹ฌ ๋Œ€ํ™”๋ฅผ ์œ„ํ•œ AI ์‹œ์Šคํ…œ์ž…๋‹ˆ๋‹ค.")
    with gr.Tabs():
        with gr.Tab("๐Ÿ’ฌ ์ฑ„ํŒ…"):
            # ChatML-style wrapper placed around the user's message.
            chat_prompt = "<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
            chatbot = gr.Chatbot(height=500, label="๋Œ€ํ™”์ฐฝ", type="messages")
            with gr.Row():
                with gr.Column(scale=4):
                    msg = gr.Textbox(label="๋ฉ”์‹œ์ง€", placeholder="์ด๋ฏธ์ง€๋‚˜ PDF๋ฅผ ์ฒจ๋ถ€ํ•˜๊ณ  ์งˆ๋ฌธํ•ด๋ณด์„ธ์š”!", lines=3, show_label=False)
                with gr.Column(scale=1, min_width=150):
                    file_input = gr.File(label="ํŒŒ์ผ ์—…๋กœ๋“œ", file_types=[".pdf", ".png", ".jpg", ".jpeg"])

            def respond(message, chat_history, file):
                # Generate the reply, then append both turns in "messages" format
                # ({"role": ..., "content": ...}) as required by type="messages".
                bot_message = generate_response(chat_prompt, message, file)
                chat_history.append({"role": "user", "content": message})
                chat_history.append({"role": "assistant", "content": bot_message})
                # First output clears the textbox; second updates the chat.
                return "", chat_history

            msg.submit(respond, [msg, chatbot, file_input], [msg, chatbot])
        with gr.Tab("โš™๏ธ ์‹œ์Šคํ…œ ์ •๋ณด"):
            # Status strings reflect module-load state captured at import time.
            gr.Markdown(f"**๋ชจ๋ธ**: `{MODEL_NAME}`")
            gr.Markdown(f"**๋ชจ๋ธ ์ƒํƒœ**: `{'โœ… ๋กœ๋“œ๋จ' if MODEL_LOADED else 'โŒ ๋กœ๋“œ ์‹คํŒจ'}`")
if __name__ == "__main__":
    # share=True also creates a public link reachable from outside.
    demo.launch(share=True)