Spaces:
Sleeping
Sleeping
| # app.py | |
| import streamlit as st | |
| from PIL import Image | |
| import torch | |
| from transformers import Blip2Processor, Blip2ForConditionalGeneration, pipeline | |
| import mediapipe as mp | |
| import cv2 | |
| import numpy as np | |
| import tempfile | |
# -------------------------
# 1. BLIP-2 model loading
# -------------------------
st.title("๐ท AI ํฌ์ฆ ์ดฌ์ ๊ฐ์ด๋")
st.caption("์ ๋ก๋ํ ์ด๋ฏธ์ง๋ฅผ ๋ถ์ํ๊ณ ๋ฐ๋ผํ๊ธฐ ์ํ ํฌ์ฆ ์๋ด ๋ฐ ์ดฌ์ ๊ฐ์ด๋๋ฅผ ์์ฑํฉ๋๋ค")

# FIX: without @st.cache_resource Streamlit re-runs this loader on every
# widget interaction, re-instantiating a multi-GB model each time.
@st.cache_resource
def load_blip2():
    """Load the BLIP-2 captioning model once and cache it across reruns.

    Returns:
        tuple: (processor, model) — the HF processor plus the fp16 model
        placed onto devices by accelerate via device_map="auto".
    """
    processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b",
        device_map="auto",
        torch_dtype=torch.float16,
    )
    return processor, model

processor, model = load_blip2()
# -------------------------
# 2. DistilGPT2 text generator loading
# -------------------------
# FIX: building the pipeline at module level re-downloads/re-instantiates it
# on every Streamlit rerun; cache it so reruns reuse the same generator.
@st.cache_resource
def _load_gpt_generator():
    """Build the distilgpt2 text-generation pipeline once and cache it."""
    return pipeline("text-generation", model="distilgpt2")

gpt_generator = _load_gpt_generator()
# -------------------------
# 3. Image upload & analysis
# -------------------------
uploaded_file = st.file_uploader("์ด๋ฏธ์ง๋ฅผ ์ ๋ก๋ํ์ธ์", type=["jpg", "jpeg", "png"])
if uploaded_file:
    image = Image.open(uploaded_file).convert("RGB")
    # use_container_width replaces the deprecated use_column_width parameter.
    st.image(image, caption="์ ๋ก๋ํ ์ด๋ฏธ์ง", use_container_width=True)

    # Generate an image description with BLIP-2 (fp16 inputs to match the model).
    inputs = processor(image, return_tensors="pt").to(model.device, torch.float16)
    with torch.no_grad():
        generated_ids = model.generate(**inputs)
    caption = processor.decode(generated_ids[0], skip_special_tokens=True)
    st.subheader("๐ง ์ด๋ฏธ์ง ์ค๋ช ")
    st.info(caption)

    # Free-form Q&A about the image, grounded on the generated caption.
    user_q = st.text_input("์ด๋ฏธ์ง์ ๋ํด ์ง๋ฌธํด๋ณด์ธ์ (์: ์ด ์ฌ๋์ ์ท ์คํ์ผ์?)")
    if user_q:
        prompt = f"์ง๋ฌธ: {user_q}\n์ด๋ฏธ์ง ์ค๋ช : {caption}\n๋ต๋ณ:"
        # FIX: max_length counts the prompt tokens too, so a long caption+question
        # exceeds 100 and generation fails or yields nothing; max_new_tokens
        # bounds only the generated continuation.
        ans = gpt_generator(prompt, max_new_tokens=100, do_sample=True)[0]["generated_text"].split("๋ต๋ณ:")[-1]
        st.success(ans)

    # -------------------------
    # 4. Pose skeleton visualization
    # -------------------------
    st.subheader("๐บ ํฌ์ฆ ์์๋ผ์ธ")
    mp_pose = mp.solutions.pose
    # FIX: the original wrote the PIL image to a NamedTemporaryFile with no
    # suffix — PIL.Image.save() cannot infer a format from an extension-less
    # path and raises ValueError (and reopening a still-open temp file by name
    # fails on Windows). The image is already RGB in memory, so convert it to
    # a numpy array directly; no temp file or cv2 round-trip is needed.
    img_rgb = np.array(image)
    # Pose supports the context-manager protocol; closing it releases the
    # solution's graph resources instead of leaking them across reruns.
    with mp_pose.Pose(static_image_mode=True) as pose:
        results = pose.process(img_rgb)
    if results.pose_landmarks:
        annotated = img_rgb.copy()
        mp_drawing = mp.solutions.drawing_utils
        mp_drawing.draw_landmarks(
            annotated,
            results.pose_landmarks,
            mp_pose.POSE_CONNECTIONS,
            landmark_drawing_spec=mp_drawing.DrawingSpec(color=(255, 0, 0), thickness=2, circle_radius=3),
            connection_drawing_spec=mp_drawing.DrawingSpec(color=(0, 255, 0), thickness=2),
        )
        st.image(annotated, caption="ํฌ์ฆ ์๊ฐํ", use_container_width=True)
    else:
        st.warning("ํฌ์ฆ๋ฅผ ์ถ์ถํ์ง ๋ชปํ์ต๋๋ค.")

    # -------------------------
    # 5. Shooting guide generation
    # -------------------------
    st.subheader("๐ฏ ์ดฌ์ ๊ฐ์ด๋")
    prompt = f"์ด๋ฏธ์ง ์ค๋ช : {caption}\n์ด ์ด๋ฏธ์ง๋ฅผ ์ฐธ๊ณ ํด ๋น์ทํ ์ฌ์ง์ ์ฐ์ผ๋ ค๋ฉด ์ด๋ป๊ฒ ํฌ์ฆ๋ฅผ ์ก์์ผ ํ ๊น?"
    # Same max_new_tokens fix as above: bound the continuation, not the total.
    instruction = gpt_generator(prompt, max_new_tokens=100, do_sample=True)[0]['generated_text'].split("์ฐ์ผ๋ ค๋ฉด")[-1]
    st.markdown(f"**ํฌ์ฆ ๊ฐ์ด๋:** {instruction.strip()}")