Spaces:
Running
Running
Commit
·
603c45b
1
Parent(s):
5756752
Add application file
Browse files
app.py
CHANGED
|
@@ -5,79 +5,69 @@ import onnxruntime as rt
|
|
| 5 |
import os
|
| 6 |
import spaces
|
| 7 |
import torch
|
| 8 |
-
|
| 9 |
-
from
|
| 10 |
-
from scipy.special import softmax # Used for computing probabilities from model output
|
| 11 |
-
import requests # Used for making HTTP API calls (Typhoon 2.5)
|
| 12 |
|
| 13 |
-
# 1.
|
| 14 |
# ----------------------------------------------------
|
|
|
|
| 15 |
ONNX_MODEL_PATH = "model.onnx"
|
| 16 |
CLASS_LABELS_FILE = "class_labels.txt"
|
| 17 |
-
MODEL_ID = 'facebook/convnext-tiny-224'
|
| 18 |
|
| 19 |
-
#
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
API_URL = f"https://router.huggingface.co/models/{MODEL_NAME}" # <--- แก้ไข URL เป็น router.huggingface.co
|
| 24 |
|
| 25 |
-
|
| 26 |
-
# Load character classes from the file created during training
|
| 27 |
try:
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
print(f"
|
| 33 |
-
|
| 34 |
-
# Load ONNX Runtime Session
|
| 35 |
-
try:
|
| 36 |
-
print(f"Attempting to load ONNX model from: {ONNX_MODEL_PATH}") # <-- NEW DEBUG LINE
|
| 37 |
sess = rt.InferenceSession(ONNX_MODEL_PATH)
|
| 38 |
onnx_input_name = sess.get_inputs()[0].name
|
| 39 |
onnx_output_name = sess.get_outputs()[0].name
|
| 40 |
-
# Load Image Processor (essential for correct image preparation)
|
| 41 |
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
|
| 42 |
print("ONNX model and Image Processor loaded successfully.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
except Exception as e:
|
| 44 |
-
|
| 45 |
-
print(
|
| 46 |
-
print("Please ensure model.onnx is tracked by Git LFS and is uploaded correctly.")
|
| 47 |
sess = None
|
|
|
|
|
|
|
| 48 |
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
return "Error: HF_TOKEN is not set in Hugging Face Space secrets."
|
| 56 |
-
|
| 57 |
-
headers = {"Authorization": f"Bearer {HF_TOKEN}"}
|
| 58 |
-
response = requests.post(API_URL, headers=headers, json=payload)
|
| 59 |
-
|
| 60 |
-
# Check for non-successful status codes (e.g., 401 Unauthorized, 503 Service Unavailable)
|
| 61 |
-
if response.status_code != 200:
|
| 62 |
-
return f"Error {response.status_code}: API call failed. {response.text}"
|
| 63 |
-
|
| 64 |
-
try:
|
| 65 |
-
# Extract the generated text from the response structure
|
| 66 |
-
result = response.json()[0]['generated_text']
|
| 67 |
-
# Remove the input prompt part from the output text
|
| 68 |
-
return result.split(payload['inputs'])[-1].strip()
|
| 69 |
-
except Exception as e:
|
| 70 |
-
return f"Error processing API response: {e}"
|
| 71 |
|
| 72 |
|
| 73 |
-
#
|
| 74 |
-
#
|
| 75 |
-
#
|
| 76 |
CHARACTER_INFO = {
|
| 77 |
"Ace": "โพโทกัส ดี เอส พี่ชายบุญธรรมของลูฟี่ ผู้ใช้พลังผลปีศาจเมระ เมระ",
|
| 78 |
"Luffy": "มังกี้ ดี ลูฟี่ กัปตันกลุ่มโจรสลัดหมวกฟาง ผู้ใฝ่ฝันจะเป็นราชาโจรสลัด",
|
| 79 |
"Zoro": "โรโรโนอา โซโล นักดาบสามเล่มแห่งกลุ่มหมวกฟาง ผู้มีเป้าหมายเป็นนักดาบอันดับหนึ่งของโลก",
|
| 80 |
-
"Nami": "นามิ
|
| 81 |
"Sanji": "ซันจิ กุ๊กแห่งกลุ่มโจรสลัดหมวกฟาง และเป็นสุดยอดนักสู้ที่ใช้เท้าในการต่อสู้",
|
| 82 |
"Chopper": "โทนี่ โทนี่ ช็อปเปอร์ หมอประจำเรือ ผู้มีใจรักเพื่อนและอ่อนไหวที่สุดในกลุ่ม",
|
| 83 |
"Robin": "นิโค โรบิน นักโบราณคดี ผู้เดียวที่อ่านโพเนกลีฟได้",
|
|
@@ -92,11 +82,14 @@ CHARACTER_INFO = {
|
|
| 92 |
"Rayleigh": "ซิลเวอร์ส เรย์ลี่ อดีตมือขวาของราชาโจรสลัด โกลด์ ดี. โรเจอร์",
|
| 93 |
}
|
| 94 |
|
| 95 |
-
|
| 96 |
-
def generate_thai_response(character_name, confidence):
|
| 97 |
"""
|
| 98 |
-
|
| 99 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
info = CHARACTER_INFO.get(character_name, "ตัวละครวันพีซ")
|
| 101 |
|
| 102 |
# 1. Build a clear, instructional prompt for the LLM
|
|
@@ -106,51 +99,54 @@ def generate_thai_response(character_name, confidence):
|
|
| 106 |
f"กรุณาสร้างข้อความตอบกลับที่เป็นมิตรและเป็นภาษาไทย โดยขึ้นต้นด้วย 'ยืนยันผลการทำนาย!' "
|
| 107 |
f"และรวมข้อมูลทั้งหมดนี้เข้าด้วยกันในประโยคเดียวโดยใช้ Markdown bold สำหรับชื่อตัวละครและความมั่นใจ (XX.XX%)."
|
| 108 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
-
#
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
}
|
| 117 |
-
}
|
| 118 |
-
|
| 119 |
-
# 3. Call the API and handle potential errors
|
| 120 |
-
llm_response = query_typhoon_api(payload)
|
| 121 |
-
|
| 122 |
-
if llm_response.startswith("Error"):
|
| 123 |
-
# Fallback to a static, simple response if API fails
|
| 124 |
-
thai_name = info.split(' ')[0]
|
| 125 |
-
return (f"⚠️ LLM API ไม่ตอบสนอง: ตัวละครคือ **{thai_name}** ({info}) "
|
| 126 |
-
f"[ความมั่นใจ: **{confidence*100:.2f}%**]")
|
| 127 |
|
| 128 |
-
return llm_response
|
| 129 |
|
| 130 |
-
#
|
| 131 |
# ----------------------------------------------------
|
| 132 |
-
|
|
|
|
| 133 |
def predict_one_piece_character(pil_image):
|
| 134 |
if pil_image is None or sess is None:
|
| 135 |
return "⚠️ โมเดลไม่พร้อมใช้งาน กรุณาตรวจสอบไฟล์ ONNX และการตั้งค่า"
|
| 136 |
|
| 137 |
try:
|
| 138 |
-
#
|
| 139 |
inputs = processor(images=pil_image, return_tensors="np")
|
| 140 |
onnx_input = inputs['pixel_values'].astype(np.float32)
|
| 141 |
|
| 142 |
-
#
|
| 143 |
onnx_predictions = sess.run([onnx_output_name], {onnx_input_name: onnx_input})
|
| 144 |
logits = onnx_predictions[0].squeeze()
|
| 145 |
|
| 146 |
-
#
|
| 147 |
probabilities = softmax(logits)
|
| 148 |
predicted_index = np.argmax(probabilities)
|
| 149 |
predicted_character = CHARACTER_LABELS[predicted_index]
|
| 150 |
confidence = probabilities[predicted_index].item()
|
| 151 |
|
| 152 |
-
#
|
| 153 |
-
final_response =
|
| 154 |
|
| 155 |
return final_response
|
| 156 |
|
|
@@ -158,14 +154,14 @@ def predict_one_piece_character(pil_image):
|
|
| 158 |
print(f"RUNTIME ERROR: {e}")
|
| 159 |
return f"เกิดข้อผิดพลาดในการทำนาย: {e}"
|
| 160 |
|
| 161 |
-
#
|
| 162 |
# ----------------------------------------------------
|
| 163 |
interface = gr.Interface(
|
| 164 |
fn=predict_one_piece_character,
|
| 165 |
inputs=gr.Image(type="pil", label="อัปโหลดรูปภาพตัวละครวันพีซ"),
|
| 166 |
-
outputs=gr.Textbox(label="ผลการทำนายชื่อตัวละคร (
|
| 167 |
-
title="🏴☠️ One Piece Classifier (ConvNeXt ONNX + Typhoon 2.5)",
|
| 168 |
-
description="
|
| 169 |
)
|
| 170 |
|
| 171 |
if __name__ == "__main__":
|
|
|
|
| 5 |
import os
|
| 6 |
import spaces
|
| 7 |
import torch
|
| 8 |
+
from transformers import AutoImageProcessor, AutoModelForCausalLM, AutoTokenizer # NEW: LLM Imports
|
| 9 |
+
from scipy.special import softmax
|
|
|
|
|
|
|
| 10 |
|
| 11 |
+
# 1. MODEL CONFIGURATION AND LOADING
|
| 12 |
# ----------------------------------------------------
|
| 13 |
+
# 1.1 ONNX Model (Image Classifier) Configuration
|
| 14 |
ONNX_MODEL_PATH = "model.onnx"
|
| 15 |
CLASS_LABELS_FILE = "class_labels.txt"
|
| 16 |
+
MODEL_ID = 'facebook/convnext-tiny-224'
|
| 17 |
|
| 18 |
+
# 1.2 LLM Configuration (Loaded Locally)
|
| 19 |
+
LLM_MODEL_NAME = "scb10x/typhoon2.5-qwen3-4b"
|
| 20 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 21 |
+
print(f"LLM Device: {device}")
|
|
|
|
| 22 |
|
| 23 |
+
# Load ONNX Runtime Session and LLM
|
|
|
|
| 24 |
try:
|
| 25 |
+
# 1. Load ONNX Runtime (ConvNeXt)
|
| 26 |
+
if not os.path.exists(ONNX_MODEL_PATH):
|
| 27 |
+
raise FileNotFoundError(f"ONNX Model file not found at: {ONNX_MODEL_PATH}")
|
| 28 |
+
|
| 29 |
+
print(f"Attempting to load ONNX model from: {ONNX_MODEL_PATH}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
sess = rt.InferenceSession(ONNX_MODEL_PATH)
|
| 31 |
onnx_input_name = sess.get_inputs()[0].name
|
| 32 |
onnx_output_name = sess.get_outputs()[0].name
|
|
|
|
| 33 |
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
|
| 34 |
print("ONNX model and Image Processor loaded successfully.")
|
| 35 |
+
|
| 36 |
+
# 2. Load LLM (Typhoon 2.5) Locally
|
| 37 |
+
print(f"Attempting to load LLM model: {LLM_MODEL_NAME} onto {device}...")
|
| 38 |
+
llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME, trust_remote_code=True)
|
| 39 |
+
llm_model = AutoModelForCausalLM.from_pretrained(
|
| 40 |
+
LLM_MODEL_NAME,
|
| 41 |
+
trust_remote_code=True,
|
| 42 |
+
torch_dtype=torch.float16, # Use float16 for efficiency
|
| 43 |
+
low_cpu_mem_usage=True,
|
| 44 |
+
)
|
| 45 |
+
llm_model.to(device)
|
| 46 |
+
print("Typhoon 2.5 LLM loaded successfully.")
|
| 47 |
+
|
| 48 |
except Exception as e:
|
| 49 |
+
print(f"FATAL ERROR LOADING MODELS: {e}")
|
| 50 |
+
print("Please ensure GPU is available and files are uploaded correctly (including model.onnx.data).")
|
|
|
|
| 51 |
sess = None
|
| 52 |
+
llm_model = None
|
| 53 |
+
llm_tokenizer = None
|
| 54 |
|
| 55 |
+
# Load character classes
|
| 56 |
+
try:
|
| 57 |
+
with open(CLASS_LABELS_FILE, 'r', encoding='utf-8') as f:
|
| 58 |
+
CHARACTER_LABELS = [line.strip() for line in f.readlines()]
|
| 59 |
+
except FileNotFoundError:
|
| 60 |
+
CHARACTER_LABELS = ['Luffy', 'Zoro', 'Nami', 'Sanji', 'Chopper', 'Franky', 'Brook', 'Usopp', 'Jinbei', 'Robin', 'Ace', 'Law', 'Shanks', 'Kurohige', 'Mihawk', 'Rayleigh']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
|
| 63 |
+
# 2. LLM GENERATION FUNCTION (Local Inference)
|
| 64 |
+
# ----------------------------------------------------
|
| 65 |
+
# ฐานข้อมูลข้อมูลเสริมตัวละคร
|
| 66 |
CHARACTER_INFO = {
|
| 67 |
"Ace": "โพโทกัส ดี เอส พี่ชายบุญธรรมของลูฟี่ ผู้ใช้พลังผลปีศาจเมระ เมระ",
|
| 68 |
"Luffy": "มังกี้ ดี ลูฟี่ กัปตันกลุ่มโจรสลัดหมวกฟาง ผู้ใฝ่ฝันจะเป็นราชาโจรสลัด",
|
| 69 |
"Zoro": "โรโรโนอา โซโล นักดาบสามเล่มแห่งกลุ่มหมวกฟาง ผู้มีเป้าหมายเป็นนักดาบอันดับหนึ่งของโลก",
|
| 70 |
+
"Nami": "นามิ นักเดินเรือสาวแห่งกลุ่มโจรสลัดหมวกฟาง และเป็นนักทำแผนที่มือฉมัง",
|
| 71 |
"Sanji": "ซันจิ กุ๊กแห่งกลุ่มโจรสลัดหมวกฟาง และเป็นสุดยอดนักสู้ที่ใช้เท้าในการต่อสู้",
|
| 72 |
"Chopper": "โทนี่ โทนี่ ช็อปเปอร์ หมอประจำเรือ ผู้มีใจรักเพื่อนและอ่อนไหวที่สุดในกลุ่ม",
|
| 73 |
"Robin": "นิโค โรบิน นักโบราณคดี ผู้เดียวที่อ่านโพเนกลีฟได้",
|
|
|
|
| 82 |
"Rayleigh": "ซิลเวอร์ส เรย์ลี่ อดีตมือขวาของราชาโจรสลัด โกลด์ ดี. โรเจอร์",
|
| 83 |
}
|
| 84 |
|
| 85 |
+
def generate_typhoon_response(character_name, confidence):
|
|
|
|
| 86 |
"""
|
| 87 |
+
ฟังก์ชัน LLM ที่ใช้ Local Inference ภายใน Space
|
| 88 |
"""
|
| 89 |
+
if llm_model is None:
|
| 90 |
+
return (f"❌ LLM ไม่พร้อมใช้งาน: ตัวละครคือ **{character_name}** "
|
| 91 |
+
f"[ความมั่นใจ: **{confidence*100:.2f}%**]")
|
| 92 |
+
|
| 93 |
info = CHARACTER_INFO.get(character_name, "ตัวละครวันพีซ")
|
| 94 |
|
| 95 |
# 1. Build a clear, instructional prompt for the LLM
|
|
|
|
| 99 |
f"กรุณาสร้างข้อความตอบกลับที่เป็นมิตรและเป็นภาษาไทย โดยขึ้นต้นด้วย 'ยืนยันผลการทำนาย!' "
|
| 100 |
f"และรวมข้อมูลทั้งหมดนี้เข้าด้วยกันในประโยคเดียวโดยใช้ Markdown bold สำหรับชื่อตัวละครและความมั่นใจ (XX.XX%)."
|
| 101 |
)
|
| 102 |
+
|
| 103 |
+
# 2. Generate text using the local LLM
|
| 104 |
+
messages = [{"role": "user", "content": prompt}]
|
| 105 |
+
input_ids = llm_tokenizer.apply_chat_template(
|
| 106 |
+
messages, add_generation_prompt=True, return_tensors="pt"
|
| 107 |
+
).to(device)
|
| 108 |
+
|
| 109 |
+
output_ids = llm_model.generate(
|
| 110 |
+
input_ids,
|
| 111 |
+
max_new_tokens=100,
|
| 112 |
+
temperature=0.7,
|
| 113 |
+
do_sample=True,
|
| 114 |
+
pad_token_id=llm_tokenizer.eos_token_id,
|
| 115 |
+
)
|
| 116 |
|
| 117 |
+
# 3. Decode response
|
| 118 |
+
response = llm_tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
| 119 |
+
# Remove the input prompt from the response
|
| 120 |
+
response_text = response.split(prompt)[-1].strip()
|
| 121 |
+
|
| 122 |
+
return response_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
|
|
|
| 124 |
|
| 125 |
+
# 3. ONNX INFERENCE FUNCTION
|
| 126 |
# ----------------------------------------------------
|
| 127 |
+
# เราจะใช้ @spaces.GPU ตรงนี้เพื่อให้ LLM (ซึ่งอยู่ในฟังก์ชันที่ถูกเรียก) รันบน GPU ด้วย
|
| 128 |
+
@spaces.GPU # ใช้ GPU สำหรับการรัน LLM
|
| 129 |
def predict_one_piece_character(pil_image):
|
| 130 |
if pil_image is None or sess is None:
|
| 131 |
return "⚠️ โมเดลไม่พร้อมใช้งาน กรุณาตรวจสอบไฟล์ ONNX และการตั้งค่า"
|
| 132 |
|
| 133 |
try:
|
| 134 |
+
# 3.1 Preprocessing (ConvNeXt standard input)
|
| 135 |
inputs = processor(images=pil_image, return_tensors="np")
|
| 136 |
onnx_input = inputs['pixel_values'].astype(np.float32)
|
| 137 |
|
| 138 |
+
# 3.2 Run Inference (ConvNeXt ONNX)
|
| 139 |
onnx_predictions = sess.run([onnx_output_name], {onnx_input_name: onnx_input})
|
| 140 |
logits = onnx_predictions[0].squeeze()
|
| 141 |
|
| 142 |
+
# 3.3 Post-processing (Softmax and Argmax)
|
| 143 |
probabilities = softmax(logits)
|
| 144 |
predicted_index = np.argmax(probabilities)
|
| 145 |
predicted_character = CHARACTER_LABELS[predicted_index]
|
| 146 |
confidence = probabilities[predicted_index].item()
|
| 147 |
|
| 148 |
+
# 3.4 LLM Integration (Local Generation)
|
| 149 |
+
final_response = generate_typhoon_response(predicted_character, confidence)
|
| 150 |
|
| 151 |
return final_response
|
| 152 |
|
|
|
|
| 154 |
print(f"RUNTIME ERROR: {e}")
|
| 155 |
return f"เกิดข้อผิดพลาดในการทำนาย: {e}"
|
| 156 |
|
| 157 |
+
# 4. GRADIO INTERFACE
|
| 158 |
# ----------------------------------------------------
|
| 159 |
interface = gr.Interface(
|
| 160 |
fn=predict_one_piece_character,
|
| 161 |
inputs=gr.Image(type="pil", label="อัปโหลดรูปภาพตัวละครวันพีซ"),
|
| 162 |
+
outputs=gr.Textbox(label="ผลการทำนายชื่อตัวละคร (Typhoon 2.5 Local)"),
|
| 163 |
+
title="🏴☠️ One Piece Classifier (ConvNeXt ONNX + Typhoon 2.5 Local)",
|
| 164 |
+
description="แอปพลิเคชันจำแนกตัวละครวันพีซโดยรัน LLM ภายใน Space"
|
| 165 |
)
|
| 166 |
|
| 167 |
if __name__ == "__main__":
|