# app.py
import streamlit as st
from PIL import Image
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration, pipeline
import mediapipe as mp
import cv2
import numpy as np
import tempfile
# -------------------------
# 1. Load the BLIP-2 model
# -------------------------
st.title("📷 AI Pose Photography Guide")
st.caption("Analyzes the uploaded image and generates pose guidance and a shooting guide you can follow.")
@st.cache_resource
def load_blip2():
    processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b", device_map="auto", torch_dtype=torch.float16
    )
    return processor, model

processor, model = load_blip2()
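# Note: torch.float16 + device_map="auto" assumes a GPU is available (and that
# `accelerate` is installed); on a CPU-only machine you would likely drop both
# arguments, since fp16 inference on CPU is poorly supported.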
# -------------------------
# 2. Load the DistilGPT2 text generator (cached so it is not reloaded on every rerun)
# -------------------------
@st.cache_resource
def load_gpt():
    return pipeline("text-generation", model="distilgpt2")

gpt_generator = load_gpt()
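# Note: distilgpt2 is an English-only model, so the prompts below are written in English.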
# -------------------------
# 3. Image upload & analysis
# -------------------------
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
if uploaded_file:
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded image", use_container_width=True)
    # Generate a description with BLIP-2.
    # BatchFeature.to() casts only the floating-point tensors to fp16;
    # the integer input_ids are left untouched.
    inputs = processor(image, return_tensors="pt").to(model.device, torch.float16)
    with torch.no_grad():
        generated_ids = model.generate(**inputs)
    caption = processor.decode(generated_ids[0], skip_special_tokens=True).strip()
    st.subheader("🧠 Image description")
    st.info(caption)
    # Q&A about the image
    user_q = st.text_input("Ask a question about the image (e.g., what is this person's clothing style?)")
    if user_q:
        prompt = f"Question: {user_q}\nImage description: {caption}\nAnswer:"
        # max_new_tokens bounds only the generated continuation (unlike max_length,
        # which also counts the prompt); keep only the text after the "Answer:" anchor.
        ans = gpt_generator(prompt, max_new_tokens=60, do_sample=True)[0]["generated_text"].split("Answer:")[-1]
        st.success(ans.strip())
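    # The answer above is generated from the text caption alone; distilgpt2 never
    # sees the image itself, so treat it as best-effort rather than grounded VQA.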
    # -------------------------
    # 4. Pose skeleton visualization
    # -------------------------
    st.subheader("🕺 Pose outline")
    mp_pose = mp.solutions.pose
    pose = mp_pose.Pose(static_image_mode=True)
    # The ".png" suffix is required so PIL can infer the output format when saving.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
        image.save(tmp_file.name)
    img_bgr = cv2.imread(tmp_file.name)
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    results = pose.process(img_rgb)
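    # MediaPipe Pose returns 33 landmarks with normalized [0, 1] coordinates;
    # draw_landmarks() below converts them to pixel positions on the image.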
    if results.pose_landmarks:
        annotated = img_rgb.copy()
        mp_drawing = mp.solutions.drawing_utils
        mp_drawing.draw_landmarks(
            annotated,
            results.pose_landmarks,
            mp_pose.POSE_CONNECTIONS,
            landmark_drawing_spec=mp_drawing.DrawingSpec(color=(255, 0, 0), thickness=2, circle_radius=3),
            connection_drawing_spec=mp_drawing.DrawingSpec(color=(0, 255, 0), thickness=2),
        )
        st.image(annotated, caption="Pose visualization", use_container_width=True)
    else:
        st.warning("Could not extract a pose from the image.")
    # -------------------------
    # 5. Shooting guide generation
    # -------------------------
    st.subheader("🎯 Shooting guide")
    prompt = f"Image description: {caption}\nHow should someone pose to take a similar photo based on this image?\nPose guide:"
    # Keep only the text generated after the "Pose guide:" anchor.
    instruction = gpt_generator(prompt, max_new_tokens=60, do_sample=True)[0]["generated_text"].split("Pose guide:")[-1]
    st.markdown(f"**Pose guide:** {instruction.strip()}")
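# A sketch of how this app would be launched locally (the package list is an
# assumption inferred from the imports above, not pinned by this repo):
#   pip install streamlit torch transformers accelerate pillow opencv-python-headless mediapipe
#   streamlit run src/app.py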