suhpau committed on
Commit
c0bb180
·
verified ·
1 Parent(s): f388865

Create 안녕

Browse files
Files changed (1) hide show
  1. 안녕 +75 -0
안녕 ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Dependencies (run in a shell before executing this script, not as Python):
#   pip install transformers torch pillow requests
2
+ from transformers import BlipProcessor, BlipForQuestionAnswering
3
+ from PIL import Image
4
+ import torch
5
+ import requests
6
+ from io import BytesIO
7
+
8
class VQASystem:
    """Visual question answering helper built on a BLIP VQA checkpoint.

    The processor and model are loaded once at construction time; the
    instance can then answer one or many questions about an image.
    """

    def __init__(self, model_name="Salesforce/blip-vqa-base"):
        """Load the processor and model, then move the model to the device.

        Args:
            model_name: Model identifier passed to ``from_pretrained`` for
                both the processor and the question-answering model.
        """
        print(f"🔧 VQA 모델 로드 중: {model_name}")
        self.processor = BlipProcessor.from_pretrained(model_name)
        self.model = BlipForQuestionAnswering.from_pretrained(model_name)
        # Prefer the GPU when one is visible to torch, otherwise use the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        print("✅ 모델 로드 완료")

    def load_image(self, image_source, timeout=30):
        """Return ``image_source`` as an RGB PIL image.

        Args:
            image_source: An ``http://`` / ``https://`` URL, a local path
                (``str`` or path-like), or an already loaded PIL image.
            timeout: Seconds to wait for the HTTP request before giving up.

        Returns:
            A PIL image in RGB mode.

        Raises:
            requests.RequestException: If the download fails or the server
                answers with an error status.
            OSError: If the file cannot be opened or decoded as an image.
        """
        # An already decoded image is reused so callers can load once and
        # ask many questions without repeating I/O.
        if isinstance(image_source, Image.Image):
            if image_source.mode == "RGB":
                return image_source
            return image_source.convert("RGB")

        is_url = isinstance(image_source, str) and image_source.lower().startswith(
            ("http://", "https://")
        )
        if is_url:
            response = requests.get(image_source, timeout=timeout)
            # Fail on 4xx/5xx instead of handing an HTML error page to PIL.
            response.raise_for_status()
            with Image.open(BytesIO(response.content)) as downloaded:
                return downloaded.convert("RGB")

        # convert() produces an independent copy, so the underlying file can
        # be closed as soon as the conversion is done.
        with Image.open(image_source) as local_image:
            return local_image.convert("RGB")

    def generate_answer(self, image_path, question):
        """Answer ``question`` about the given image.

        Args:
            image_path: URL, local path, or already loaded PIL image.
            question: Question text fed to the processor with the image.

        Returns:
            The decoded answer string. On any failure this is best-effort:
            the string ``"Error: <message>"`` is returned instead of raising.
        """
        try:
            raw_image = self.load_image(image_path)

            # Build the model inputs and place them on the model's device.
            inputs = self.processor(raw_image, question, return_tensors="pt").to(self.device)

            # Generate the answer; max_new_tokens bounds the answer length.
            with torch.no_grad():
                out = self.model.generate(**inputs, max_new_tokens=50)

            return self.processor.decode(out[0], skip_special_tokens=True)
        except Exception as e:
            return f"Error: {e}"

    def batch_qa(self, image_path, questions):
        """Answer several questions about one image, printing each Q/A pair.

        The image is loaded a single time and reused for every question.

        Args:
            image_path: URL, local path, or already loaded PIL image.
            questions: Iterable of question strings.

        Returns:
            A dict mapping each question to its answer. If the image cannot
            be loaded, every question maps to the same ``"Error: ..."``
            string.
        """
        print(f"🖼️ 이미지 분석 중: {image_path}")
        questions = list(questions)
        results = {}
        if not questions:
            return results

        load_error = None
        image = None
        try:
            image = self.load_image(image_path)
        except Exception as e:
            # Keep the per-question error contract of generate_answer.
            load_error = f"Error: {e}"

        for q in questions:
            ans = load_error if load_error is not None else self.generate_answer(image, q)
            results[q] = ans
            print(f"Q: {q}\nA: {ans}\n")
        return results
53
+
54
def main():
    """Run the demo: print a banner, then ask sample questions about a dog photo."""
    banner = "=" * 60
    print(banner)
    print("Project 2: Visual Question Answering System")
    print(banner)

    system = VQASystem()

    # Sample image used for the demo run.
    image_url = "https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg"

    prompts = [
        "What animal is in the picture?",
        "What is the dog doing?",
        "What color is the dog?",
        "Is there a cat in the image?",
    ]

    system.batch_qa(image_url, prompts)
72
+
73
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
75
+