# File: app.py (final revised version)
import gradio as gr
import os
import traceback
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoImageProcessor
import torch
import fitz
from PIL import Image
from typing import Optional, List
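# Note: fitz is PyMuPDF (pip install pymupdf); PIL comes from the pillow package.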
# --- 1 & 2. Global variables, environment setup, and model loading (unchanged from the previous version) ---
tokenizer = None
model = None
image_processor = None
MODEL_LOADED = False
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
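# Heuristic: treat this as a local run when a .env file exists nearby
# or the IS_LOCAL environment variable is explicitly set to "true".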
IS_LOCAL = os.path.exists('.env') or os.path.exists('../.env') or os.getenv('IS_LOCAL') == 'true'
try:
    from dotenv import load_dotenv
    if IS_LOCAL:
        load_dotenv()
        print("✅ .env file loaded")
except ImportError:
    print("⚠️ python-dotenv is not installed")
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_NAME_SERVER = os.getenv("MODEL_NAME", "gbrabbit/lily-math-model")
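# Model source: local runs read a checkpoint from a relative directory,
# while server runs download the repository above from the Hugging Face Hub.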
MODEL_PATH_LOCAL = "../lily_llm_core/models/kanana_1_5_v_3b_instruct"
MODEL_PATH = MODEL_PATH_LOCAL if IS_LOCAL else MODEL_NAME_SERVER
print(f"============== ์์คํ
ํ๊ฒฝ ์ ๋ณด ==============")
print(f"๐ ์คํ ํ๊ฒฝ: {'๋ก์ปฌ' if IS_LOCAL else '์๋ฒ'}")
print(f"๐ ๋ชจ๋ธ ๊ฒฝ๋ก: {MODEL_PATH}")
print(f"๐ ์ฌ์ฉ ๋๋ฐ์ด์ค: {DEVICE.upper()}")
print("==========================================")
try:
    print("Loading model...")
    from modeling import KananaVForConditionalGeneration
    if IS_LOCAL:
        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(f"Local model path not found: {MODEL_PATH}")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True, local_files_only=True)
        model = KananaVForConditionalGeneration.from_pretrained(
            MODEL_PATH, torch_dtype=torch.bfloat16, trust_remote_code=True, local_files_only=True,
        ).to(DEVICE)
        image_processor = AutoImageProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True, local_files_only=True)
        print("✅ Local model and image processor loaded!")
    else:
        if not HF_TOKEN:
            raise ValueError("A Hugging Face token (HF_TOKEN) is required in the server environment.")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, token=HF_TOKEN, trust_remote_code=True)
        model = KananaVForConditionalGeneration.from_pretrained(
            MODEL_PATH, token=HF_TOKEN, torch_dtype=torch.float16, trust_remote_code=True, device_map="auto"
        )
        image_processor = AutoImageProcessor.from_pretrained(MODEL_PATH, token=HF_TOKEN, trust_remote_code=True)
        print("✅ Server model and image processor loaded!")
    MODEL_LOADED = True
except Exception as e:
    print(f"❌ Model loading failed: {e}")
    traceback.print_exc()
    MODEL_LOADED = False
# --- 3. Response generation logic (unchanged from the previous version) ---
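# PyMuPDF extracts plain text page by page; failures are returned as a
# message string instead of raising, so the chat flow keeps working.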
def extract_text_from_pdf(pdf_file_path):
    try:
        doc = fitz.open(pdf_file_path)
        text = "".join(page.get_text() for page in doc)
        doc.close()
        return text
    except Exception as e:
        print(f"PDF processing error: {e}")
        return f"An error occurred while reading the PDF file: {e}"
def generate_response(prompt_template: str, message: str, files: Optional[List] = None):
    if not MODEL_LOADED:
        return "❌ The model is not loaded."
    try:
        all_pixel_values, all_image_metas, file_texts = [], [], []
        if files:
            for file in files:
                file_path, file_extension = file.name, os.path.splitext(file.name)[1].lower()
                if file_extension == '.pdf':
                    file_texts.append(extract_text_from_pdf(file_path))
                elif file_extension in ['.png', '.jpg', '.jpeg']:
                    pil_image = Image.open(file_path).convert('RGB')
                    processed_data = image_processor(pil_image)
                    all_pixel_values.append(processed_data["pixel_values"])
                    all_image_metas.append(processed_data["image_meta"])
        # One <image> placeholder per attached image, plus any extracted PDF text.
        image_tokens = "<image>" * len(all_pixel_values)
        pdf_content = "\n\n".join(file_texts)
        full_message = message + (f"\n{image_tokens}" if image_tokens else "") + (f"\n\n[Attached PDF content]:\n{pdf_content}" if pdf_content else "")
        full_prompt = prompt_template.format(message=full_message)
        if all_image_metas:
            # Merge the per-image metadata dicts into one dict of lists.
            combined_metas = {key: [meta[key] for meta in all_image_metas] for key in all_image_metas[0]}
            inputs = tokenizer.encode_prompt(prompt=full_prompt, image_meta=combined_metas)
            inputs = {k: (v.unsqueeze(0).to(model.device) if torch.is_tensor(v) else v) for k, v in inputs.items()}
        else:
            inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
        generation_args = {
            "max_new_tokens": 32,  # caps reply length; increase for longer answers
            "temperature": 0.8,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            "top_p": 0.95,
        }
        with torch.no_grad():
            if all_pixel_values:
                outputs = model.generate(**inputs, pixel_values=all_pixel_values, image_metas=combined_metas, **generation_args)
            else:
                outputs = model.generate(**inputs, **generation_args)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only the assistant's turn from the decoded transcript.
        return response.split("<|im_start|>assistant\n")[-1].strip()
    except Exception as e:
        print(f"❌ Error while generating a response: {e}")
        traceback.print_exc()
        return f"An error occurred: {e}"
# --- 4. Gradio UI and launch (final revision) ---
with gr.Blocks(title="Lily LLM System", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧮 Lily LLM System")
    gr.Markdown("A multimodal AI system that understands and answers questions about images, PDFs, and text.")
    with gr.Tabs():
        with gr.Tab("💬 Chat"):
            # ChatML-style template; {message} is filled in by generate_response.
            chat_prompt = "<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
            chatbot = gr.Chatbot(height=320, label="Conversation", elem_id="chatbot", type="messages")
            with gr.Row():
                msg = gr.Textbox(label="Message", placeholder="Type a message", lines=3, show_label=False, scale=4)
                file_input = gr.File(label="File upload", file_count="multiple", file_types=[".pdf", ".png", ".jpg", ".jpeg"], scale=1)
                send_btn = gr.Button("Send", variant="primary", scale=1)
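            # With type="messages", the chat history is a list of
            # {"role": ..., "content": ...} dicts; respond() appends to it below.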
            # ✅ 1. respond now takes 'files' as its third argument.
            def respond(message, chat_history, files):
                if not message.strip() and not files:
                    return "", chat_history, None  # keep the files output empty too
                bot_message = generate_response(chat_prompt, message, files)
                chat_history.append({"role": "user", "content": message})
                chat_history.append({"role": "assistant", "content": bot_message})
                # ✅ 2. file_input is returned as well, so the output count matches the inputs.
                return "", chat_history, None
            # ✅ 3. 'file_input' added to the inputs list of both click and submit.
            send_btn.click(
                respond,
                inputs=[msg, chatbot, file_input],
                outputs=[msg, chatbot, file_input],  # file_input added to the outputs
                api_name="chat",  # expose this handler as 'chat' in the API
                # queue=False
            )
            msg.submit(
                respond,
                inputs=[msg, chatbot, file_input],
                outputs=[msg, chatbot, file_input],  # file_input added to the outputs
                api_name="chat",
                # queue=False
            )
with gr.Tab("โ๏ธ ์์คํ
์ ๋ณด"):
gr.Markdown(f"**์คํ ํ๊ฒฝ**: `{'๋ก์ปฌ' if IS_LOCAL else '์๋ฒ'}`")
gr.Markdown(f"**๋ชจ๋ธ ๊ฒฝ๋ก**: `{MODEL_PATH}`")
gr.Markdown(f"**๋ชจ๋ธ ์ํ**: `{'โ
๋ก๋๋จ' if MODEL_LOADED else 'โ ๋ก๋ ์คํจ'}`")
if __name__ == "__main__":
    if IS_LOCAL:
        print("\nStarting local server: http://127.0.0.1:8006")
        demo.launch(server_name="127.0.0.1", server_port=8006, share=False)
    else:
        print("\nStarting server...")
        demo.launch()
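# To try it locally: place a .env next to app.py (or export IS_LOCAL=true),
# make sure the checkpoint exists at MODEL_PATH_LOCAL, then run `python app.py`
# and open http://127.0.0.1:8006.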