import gradio as gr from PIL import Image import numpy as np import onnxruntime as rt import os import spaces import torch from transformers import AutoImageProcessor, AutoModelForCausalLM, AutoTokenizer # NEW: LLM Imports from scipy.special import softmax # 1. MODEL CONFIGURATION AND LOADING # ---------------------------------------------------- # 1.1 ONNX Model (Image Classifier) Configuration ONNX_MODEL_PATH = "model.onnx" CLASS_LABELS_FILE = "class_labels.txt" MODEL_ID = 'facebook/convnext-tiny-224' # 1.2 LLM Configuration (Loaded Locally) LLM_MODEL_NAME = "scb10x/typhoon2.5-qwen3-4b" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"LLM Device: {device}") # Load ONNX Runtime Session and LLM try: # 1. Load ONNX Runtime (ConvNeXt) if not os.path.exists(ONNX_MODEL_PATH): raise FileNotFoundError(f"ONNX Model file not found at: {ONNX_MODEL_PATH}") print(f"Attempting to load ONNX model from: {ONNX_MODEL_PATH}") sess = rt.InferenceSession(ONNX_MODEL_PATH) onnx_input_name = sess.get_inputs()[0].name onnx_output_name = sess.get_outputs()[0].name processor = AutoImageProcessor.from_pretrained(MODEL_ID) print("ONNX model and Image Processor loaded successfully.") # 2. Load LLM (Typhoon 2.5) Locally print(f"Attempting to load LLM model: {LLM_MODEL_NAME} onto {device}...") llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME, trust_remote_code=True) llm_model = AutoModelForCausalLM.from_pretrained( LLM_MODEL_NAME, trust_remote_code=True, torch_dtype=torch.float16, # Use float16 for efficiency low_cpu_mem_usage=True, ) llm_model.to(device) print("Typhoon 2.5 LLM loaded successfully.") except Exception as e: print(f"FATAL ERROR LOADING MODELS: {e}") print("Please ensure GPU is available and files are uploaded correctly (including model.onnx.data).") sess = None llm_model = None llm_tokenizer = None # Load character classes try: with open(CLASS_LABELS_FILE, 'r', encoding='utf-8') as f: CHARACTER_LABELS = [line.strip() for line in f.readlines()] except FileNotFoundError: CHARACTER_LABELS = [ 'Ace', 'Akainu', 'Brook', 'Chopper', 'Crocodile', 'Franky', 'Jinbei', 'Kurohige', 'Law', 'Luffy', 'Mihawk', 'Nami', 'Rayleigh', 'Robin', 'Sanji', 'Shanks', 'Usopp', 'Zoro' ] # 2. LLM GENERATION FUNCTION (Local Inference) # ---------------------------------------------------- # ฐานข้อมูลข้อมูลเสริมตัวละคร CHARACTER_INFO = { "Ace": "โพโทกัส ดี เอส พี่ชายบุญธรรมของลูฟี่ ผู้ใช้พลังผลปีศาจเมระ เมระ", "Akainu": "อาคาอินุ (ซาคาสุกิ) พลเรือเอกกองทัพเรือ ผู้ใช้พลังผลปีศาจมากุ มากุ (แม็กม่า)", "Brook": "บรู๊ค นักดนตรีผู้ใช้ดาบและมีชีวิตเป็นโครงกระดูก ผู้รักการร้องเพลงและมุกตลก", "Chopper": "โทนี่ โทนี่ ช็อปเปอร์ หมอประจำเรือ ผู้มีใจรักเพื่อนและอ่อนไหวที่สุดในกลุ่ม", "Crocodile": "เซอร์ คร็อกโคไดล์ อดีตเจ็ดเทพโจรสลัด ผู้ใช้พลังผลปีศาจซึนะ ซึนะ (ทราย)", "Franky": "แฟรงกี้ ช่างต่อเรือผู้สร้างเรือเธาซันด์ ซันนี่ มีพลังไซบอร์กสุดแกร่ง", "Jinbei": "จินเบ อดีตเจ็ดเทพโจรสลัด และเป็นมนุษย์เงือกผู้เชี่ยวชาญคาราเต้เงือก", "Kurohige": "มาร์แชล ดี ทีช หรือหนวดดำ ผู้เป็นหนึ่งในสี่จักรพรรดิ์คนปัจจุบัน", "Law": "ทราฟาลการ์ ลอว์ กัปตันกลุ่มโจรสลัดฮาร์ท ผู้ใช้พลังผลโอเปะ โอเปะ", "Luffy": "มังกี้ ดี ลูฟี่ กัปตันกลุ่มโจรสลัดหมวกฟาง ผู้ใฝ่ฝันจะเป็นราชาโจรสลัด", "Mihawk": "จูราคิล มิฮอว์ค สุดยอดนักดาบผู้เป็นที่มาของฉายา 'ตาเหยี่ยว'", "Nami": "นามิ นักเดินเรือสาวแห่งกลุ่มโจรสลัดหมวกฟาง และเป็นนักทำแผนที่มือฉมัง", "Rayleigh": "ซิลเวอร์ส เรย์ลี่ อดีตมือขวาของราชาโจรสลัด โกลด์ ดี. โรเจอร์", "Robin": "นิโค โรบิน นักโบราณคดี ผู้เดียวที่อ่านโพเนกลีฟได้", "Sanji": "ซันจิ กุ๊กแห่งกลุ่มโจรสลัดหมวกฟาง และเป็นสุดยอดนักสู้ที่ใช้เท้าในการต่อสู้", "Shanks": "แชงค์ส หนึ่งในสี่จักรพรรดิ์ ผู้มอบหมวกฟางให้กับลูฟี่", "Usopp": "อุซป พลซุ่มยิงและนักประดิษฐ์ ผู้มีความฝันเป็นนักรบผู้กล้าหาญแห่งท้องทะเล", "Zoro": "โรโรโนอา โซโล นักดาบสามเล่มแห่งกลุ่มหมวกฟาง ผู้มีเป้าหมายเป็นนักดาบอันดับหนึ่งของโลก", } def generate_typhoon_response(character_name, confidence): """ ฟังก์ชัน LLM ที่ใช้ Local Inference ภายใน Space """ if llm_model is None: return (f"❌ LLM ไม่พร้อมใช้งาน: ตัวละครคือ **{character_name}** " f"[ความมั่นใจ: **{confidence*100:.2f}%**]") info = CHARACTER_INFO.get(character_name, "ตัวละครวันพีซ") # 1. Build a clear, instructional prompt for the LLM prompt = ( f"จากผลการวิเคราะห์ภาพ (ความมั่นใจ {confidence*100:.2f}%), ตัวละครที่ทำนายคือ '{character_name}'. " f"ตัวละครนี้คือ {info}. " f"กรุณาสร้างข้อความตอบกลับที่เป็นมิตรและเป็นภาษาไทย โดยขึ้นต้นด้วย 'ยืนยันผลการทำนาย!' " f"และรวมข้อมูลทั้งหมดนี้เข้าด้วยกันในประโยคเดียวโดยใช้ Markdown bold สำหรับชื่อตัวละครและความมั่นใจ (XX.XX%)." ) # 2. Generate text using the local LLM messages = [{"role": "user", "content": prompt}] input_ids = llm_tokenizer.apply_chat_template( messages, add_generation_prompt=True, return_tensors="pt" ).to(device) output_ids = llm_model.generate( input_ids, max_new_tokens=256, temperature=0.7, do_sample=True, pad_token_id=llm_tokenizer.eos_token_id, ) # 3. Decode response response = llm_tokenizer.decode(output_ids[0], skip_special_tokens=True) # Remove the input prompt from the response response_text = response.split(prompt)[-1].strip() return response_text # 3. ONNX INFERENCE FUNCTION # ---------------------------------------------------- # เราจะใช้ @spaces.GPU ตรงนี้เพื่อให้ LLM (ซึ่งอยู่ในฟังก์ชันที่ถูกเรียก) รันบน GPU ด้วย @spaces.GPU # ใช้ GPU สำหรับการรัน LLM def predict_one_piece_character(pil_image): if pil_image is None or sess is None: return "⚠️ โมเดลไม่พร้อมใช้งาน กรุณาตรวจสอบไฟล์ ONNX และการตั้งค่า" try: # 3.1 Preprocessing (ConvNeXt standard input) inputs = processor(images=pil_image, return_tensors="np") onnx_input = inputs['pixel_values'].astype(np.float32) # 3.2 Run Inference (ConvNeXt ONNX) onnx_predictions = sess.run([onnx_output_name], {onnx_input_name: onnx_input}) logits = onnx_predictions[0].squeeze() # 3.3 Post-processing (Softmax and Argmax) probabilities = softmax(logits) predicted_index = np.argmax(probabilities) predicted_character = CHARACTER_LABELS[predicted_index] confidence = probabilities[predicted_index].item() # 3.4 LLM Integration (Local Generation) final_response = generate_typhoon_response(predicted_character, confidence) return final_response except Exception as e: print(f"RUNTIME ERROR: {e}") return f"เกิดข้อผิดพลาดในการทำนาย: {e}" # 4. GRADIO INTERFACE # ---------------------------------------------------- interface = gr.Interface( fn=predict_one_piece_character, inputs=gr.Image(type="pil", label="อัปโหลดรูปภาพตัวละครวันพีซ"), outputs=gr.Textbox(label="ผลการทำนายชื่อตัวละคร (Typhoon 2.5 Local)", lines=5), title="🏴‍☠️ One Piece Classifier (ConvNeXt ONNX + Typhoon 2.5 Local)", description="แอปพลิเคชันจำแนกตัวละครวันพีซโดยรัน LLM ภายใน Space" ) if __name__ == "__main__": interface.launch(inbrowser=True)