|
|
""" |
|
|
SmolVLM Webcam Auto Inference (Fine-tuned) |
|
|
3์ด๋ง๋ค ์๋์ผ๋ก inference ์ํ |
|
|
Fine-tuned on Hair classification & description dataset |
|
|
""" |
|
|
|
|
|
import torch |
|
|
from PIL import Image |
|
|
from transformers import AutoProcessor, AutoModelForImageTextToText |
|
|
from peft import PeftModel |
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
from datetime import datetime |
|
|
import time |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Runtime configuration.
# ---------------------------------------------------------------------------

# Prefer the GPU whenever CUDA is available; otherwise run on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Base checkpoint and the fine-tuned (LoRA) checkpoint layered on top of it.
BASE_MODEL_ID = "HuggingFaceTB/SmolVLM-256M-Instruct"
FINETUNED_MODEL_PATH = "/root/crying_cv_vlm/checkpoint-105"

# Seconds between automatic inference runs.
INFERENCE_INTERVAL = 3

# Startup banner so the operator can see which device/checkpoint is in use.
print(f"๐ง Device: {DEVICE}")
print(f"๐ Fine-tuned Model: {FINETUNED_MODEL_PATH}")
print("Loading model...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Model loading: base SmolVLM -> LoRA adapter -> merged inference model.
# (The redundant re-imports of AutoModelForImageTextToText and PeftModel were
# removed; both names are already imported at the top of the file.)
# ---------------------------------------------------------------------------

print("1๏ธโฃ Loading base model...")
# bfloat16 halves VRAM on GPU; CPU falls back to float32 for compatibility.
model = AutoModelForImageTextToText.from_pretrained(
    BASE_MODEL_ID,
    dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
    device_map="auto",
    attn_implementation="eager"
)

print("2๏ธโฃ Loading fine-tuned adapter...")
# Attach the LoRA adapter produced by fine-tuning to the base model.
model = PeftModel.from_pretrained(
    model,
    FINETUNED_MODEL_PATH,
    device_map="auto"
)

print("3๏ธโฃ Merging adapter...")
# Fold the adapter weights into the base model so inference pays no LoRA
# overhead, then switch to eval mode (disables dropout etc.).
model = model.merge_and_unload()
model.eval()

print("4๏ธโฃ Loading processor...")
# Tokenizer + image preprocessor saved alongside the fine-tuned checkpoint.
processor = AutoProcessor.from_pretrained(FINETUNED_MODEL_PATH)

print("โ Model loaded!")
if torch.cuda.is_available():
    print(f"๐พ VRAM: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
|
|
|
|
|
|
|
|
def inference(image, question):
    """Run one VLM inference pass on *image* with *question*.

    Args:
        image: Webcam frame as a numpy array (assumed HxWxC uint8 — TODO
            confirm against the gr.Image(type="numpy") producer) or a PIL
            image; may be None before the webcam captures a frame.
        question: Prompt text; an empty/blank question falls back to a
            generic description prompt.

    Returns:
        A ``(response_text, status_text)`` tuple. On failure the response
        contains the error message plus a full traceback.
    """

    # No frame captured yet — tell the user instead of running the model.
    if image is None:
        return "โ ๏ธ ์น์บ ์์ ์ด๋ฏธ์ง๋ฅผ ์บก์ฒํด์ฃผ์ธ์.", "๋๊ธฐ ์ค"

    # Blank question -> generic description prompt.
    if not question or question.strip() == "":
        question = "Describe this image in detail."

    try:
        # Normalize the input to an RGB PIL image.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert('RGB')
        elif not isinstance(image, Image.Image):
            return "โ ์๋ชป๋ ์ด๋ฏธ์ง ํ์", "์๋ฌ"
        elif image.mode != 'RGB':
            image = image.convert('RGB')

        # Single-turn chat message: one image placeholder + the question text.
        messages = [{
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": question}]
        }]

        # Render the chat template and batch-encode text + image tensors.
        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = processor(text=prompt, images=[image], return_tensors="pt").to(DEVICE)

        # Remember the prompt length so only newly generated tokens are decoded.
        input_len = inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            # Sampling (temperature/top_p) -> output is non-deterministic.
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=True,
                temperature=0.7,
                top_p=0.9
            )

        # Drop the prompt tokens; keep only the model's answer.
        generated_ids = generated_ids[0][input_len:]
        response = processor.decode(generated_ids, skip_special_tokens=True).strip()

        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        status = f"โ {timestamp}"

        return response if response else "(๋น ์๋ต)", status

    except Exception as e:
        # Surface the full traceback in the UI rather than crashing the app.
        import traceback
        error_msg = traceback.format_exc()
        return f"โ ์๋ฌ: {str(e)}\n\n{error_msg}", "์๋ฌ ๋ฐ์"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI layout.
# Left column: webcam capture, question box, start/stop buttons.
# Right column: model response, last-run status, auto-run status.
# ---------------------------------------------------------------------------
with gr.Blocks(title="SmolVLM Auto Inference") as demo:
    # Header / usage instructions shown above the controls.
    gr.Markdown("""
    # ๐ฅ SmolVLM ์น์บ ์๋ ์ถ๋ก (Fine-tuned)

    **3์ด๋ง๋ค ์๋์ผ๋ก ์ถ๋ก ์ ์ํํฉ๋๋ค**

    ### ๋ชจ๋ธ ์ ๋ณด:
    - **Base Model**: HuggingFaceTB/SmolVLM-256M-Instruct
    - **Fine-tuned on**: Hair classification & description dataset
    - **Training**: 5 epochs, Final loss: 1.1350

    ### ์ฌ์ฉ ๋ฐฉ๋ฒ:
    1. ์น์บ ํ์ฉ ๋ฐ ์ด๋ฏธ์ง ์บก์ฒ
    2. ์ง๋ฌธ ์๋ ฅ
    3. "๐ ์๋ ์ถ๋ก ์์" ๋ฒํผ ํด๋ฆญ
    4. 3์ด๋ง๋ค ์๋์ผ๋ก ์ถ๋ก ๋ฉ๋๋ค
    5. "โธ๏ธ ์ค์ง" ๋ฒํผ์ผ๋ก ๋ฉ์ถ ์ ์์ต๋๋ค
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Live webcam feed; streamed frames arrive as numpy arrays.
            webcam = gr.Image(
                label="๐ท ์น์บ ",
                type="numpy",
                sources=["webcam"],
                streaming=True,
                height=400
            )

            # Free-form prompt; defaults to the hair-length classification task.
            question = gr.Textbox(
                label="๐ฌ ์ง๋ฌธ",
                placeholder="์ด๋ฏธ์ง์ ๋ํด ๋ฌผ์ด๋ณด๊ณ ์ถ์ ๊ฒ์ ์๋ ฅํ์ธ์",
                value="Classify the hair length in this image. Possible values: short, mid, long. Output only one word.",
                lines=3
            )

            with gr.Row():
                start_btn = gr.Button("๐ ์๋ ์ถ๋ก ์์", variant="primary", scale=2)
                stop_btn = gr.Button("โธ๏ธ ์ค์ง", variant="stop", scale=1)

        with gr.Column(scale=1):
            # Latest model response text.
            output = gr.Textbox(
                label="๐ค ์๋ต",
                lines=15,
                max_lines=20
            )

            # Timestamp / error message from the most recent inference.
            status = gr.Textbox(
                label="๐ ์ํ",
                value="๋๊ธฐ ์ค",
                lines=1
            )

            # Whether the auto-inference timer is currently running.
            auto_status = gr.Textbox(
                label="๐ ์๋ ์ถ๋ก ์ํ",
                value="๋ฉ์ถค",
                lines=1
            )

    # Clickable example prompts that fill the question box.
    gr.Markdown("### ๐ก ์์ ์ง๋ฌธ:")
    gr.Examples(
        examples=[
            ["Classify the hair length in this image. Possible values: short, mid, long. Output only one word."],
            ["Describe the person's hair style, color, and texture in detail."],
            ["What is the hair length? Answer in one word: short, mid, or long."],
            ["Describe what you see in this image."],
            ["์ด ์ฌ๋์ ๋จธ๋ฆฌ ๊ธธ์ด๋ฅผ ๋ถ๋ฅํ์ธ์. ๊ฐ๋ฅํ ๊ฐ: short, mid, long"],
        ],
        inputs=[question],
    )

    # Per-session state: auto-run flag and wall-clock time of the last run.
    is_auto_running = gr.State(value=False)
    last_inference_time = gr.State(value=0)
|
|
|
|
|
def start_auto_inference(): |
|
|
"""์๋ ์ถ๋ก ์์""" |
|
|
|
|
|
return True, "โถ๏ธ ์คํ ์ค (3์ด ๊ฐ๊ฒฉ)", gr.Timer(value=0.5, active=True), time.time() - INFERENCE_INTERVAL |
|
|
|
|
|
def stop_auto_inference(): |
|
|
"""์๋ ์ถ๋ก ์ค์ง""" |
|
|
return False, "โธ๏ธ ๋ฉ์ถค", gr.Timer(value=0.5, active=False) |
|
|
|
|
|
def auto_inference_loop(image, question_text, is_running, last_time): |
|
|
"""์๋ ์ถ๋ก ๋ฃจํ (3์ด๋ง๋ค ์คํ)""" |
|
|
if not is_running: |
|
|
return gr.update(), gr.update(), last_time |
|
|
|
|
|
current_time = time.time() |
|
|
|
|
|
|
|
|
if image is None: |
|
|
return gr.update(), "โ ๏ธ ์น์บ ์ด๋ฏธ์ง๋ฅผ ์บก์ฒํด์ฃผ์ธ์", last_time |
|
|
|
|
|
|
|
|
if current_time - last_time >= INFERENCE_INTERVAL: |
|
|
result, status_msg = inference(image, question_text) |
|
|
return result, status_msg, current_time |
|
|
else: |
|
|
|
|
|
remaining = INFERENCE_INTERVAL - (current_time - last_time) |
|
|
return gr.update(), f"โฑ๏ธ ๋ค์ ์ถ๋ก ๊น์ง {remaining:.1f}์ด", last_time |
|
|
|
|
|
|
|
|
    # Polling timer: ticks every 0.5 s while active; auto_inference_loop
    # decides on each tick whether the full inference interval has elapsed.
    timer = gr.Timer(value=0.5, active=False)

    # Start button: set the running flag, update the status label, activate
    # the timer, and reset the last-run timestamp.
    start_btn.click(
        fn=start_auto_inference,
        inputs=[],
        outputs=[is_auto_running, auto_status, timer, last_inference_time]
    )

    # Stop button: clear the running flag and deactivate the timer.
    stop_btn.click(
        fn=stop_auto_inference,
        inputs=[],
        outputs=[is_auto_running, auto_status, timer]
    )

    # Every tick feeds the latest webcam frame + question into the loop.
    timer.tick(
        fn=auto_inference_loop,
        inputs=[webcam, question, is_auto_running, last_inference_time],
        outputs=[output, status, last_inference_time]
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Bind address/port for the Gradio server. The banner is derived from
    # these values — previously it printed 7860 while launch() used 8085.
    host = "0.0.0.0"
    port = 8085

    print("\n" + "="*70)
    print(f"๐ Launching at http://{host}:{port}")
    print("="*70 + "\n")

    demo.launch(
        server_name=host,
        server_port=port,
        share=False,        # no public Gradio share link
        show_error=True     # surface Python errors in the browser UI
    )
|
|
|
|
|
|