saa231 committed on
Commit
dd0f882
·
verified ·
1 Parent(s): afa4a63

Upload project_model.py

Browse files
Files changed (1) hide show
  1. project_model.py +101 -0
project_model.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-
"""project_model.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1oopkA5yIlfizFuhXOPmTK7MUNh3Qasa3
"""

# project_module.py

import os
import time

import cv2
import numpy as np
import torch
from PIL import Image
from transformers import DPTFeatureExtractor, DPTForDepthEstimation, pipeline
from TTS.api import TTS
from ultralytics import YOLO

# ---------------------------------------------------------------------------
# Model loading — runs once at import time. Weights are downloaded on first
# use, so the initial import can take a while.
# ---------------------------------------------------------------------------

# Decide once whether CUDA is available and reuse the flag everywhere.
_use_cuda = torch.cuda.is_available()
device = "cuda" if _use_cuda else "cpu"

# Object detector (small YOLOv8 checkpoint).
yolo_model = YOLO("yolov8n.pt")

# Monocular depth estimator (DPT) in eval mode on the chosen device.
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device).eval()
depth_feat = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")

# Speech-to-text for the user's spoken question.
# transformers pipelines address devices by index: 0 = first GPU, -1 = CPU.
whisper_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=0 if _use_cuda else -1,
)

# Multimodal LLM used to answer the question about the image.
gemma_pipe = pipeline(
    "image-text-to-text",
    model="google/gemma-3-4b-it",
    device=0 if _use_cuda else -1,
    torch_dtype=torch.bfloat16 if _use_cuda else torch.float32,
)

# Text-to-speech engine that voices the final answer.
tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC")
+
34
# Function to process image and audio

def _build_context_description(context):
    """Turn a list of detected-object records into one English sentence.

    Each record is a dict with ``label``, ``confidence``, ``avg_depth``
    (float or None) and ``position`` keys, as produced by process_inputs().
    Returns a sentence like "In the image, a person (0.91 confidence) is at
    3.2 units on the left.".
    """
    if not context:
        # Without this guard an empty detection list yields the dangling
        # sentence "In the image, ." which confuses the downstream LLM.
        return "In the image, no objects were detected."
    descriptions = []
    for obj in context:
        # BUG FIX: compare against None explicitly — a legitimate average
        # depth of 0.0 is falsy and was previously reported as "unknown".
        if obj["avg_depth"] is not None:
            d = f"{obj['avg_depth']:.1f} units"
        else:
            d = "unknown"
        s = obj.get("position", "unknown")
        c = obj.get("confidence", 0.0)
        descriptions.append(f"a {obj['label']} ({c:.2f} confidence) is at {d} on the {s}")
    return "In the image, " + ", ".join(descriptions) + "."


def process_inputs(image: Image.Image, audio_path: str):
    """Answer a spoken question about an image and speak the answer.

    Pipeline: YOLO object detection plus DPT depth estimation build a
    textual scene description; Whisper transcribes the question found in
    *audio_path*; Gemma answers given the image, the scene context, and the
    question; finally Coqui TTS renders the answer to "response.wav".

    Args:
        image: Input scene as a PIL image (assumed RGB).
        audio_path: Path to an audio file containing the user's question.

    Returns:
        Tuple of (answer text, path to the generated speech WAV file).
    """
    # Convert the PIL image to OpenCV's BGR layout for YOLO.
    rgb_image = np.array(image)
    cv2_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
    pil_image = image

    # YOLO detection.
    yolo_results = yolo_model.predict(cv2_image)[0]
    boxes = yolo_results.boxes
    class_names = yolo_model.names

    # DPT (MiDaS-style) monocular depth, resized to the input resolution so
    # detection boxes can index straight into the depth map.
    depth_inputs = depth_feat(images=pil_image, return_tensors="pt").to(device)
    with torch.no_grad():
        depth_output = depth_model(**depth_inputs)
    depth_map = depth_output.predicted_depth.squeeze().cpu().numpy()
    depth_map_resized = cv2.resize(depth_map, (rgb_image.shape[1], rgb_image.shape[0]))

    # One record per detection: label, confidence, mean depth inside the
    # box, and a coarse horizontal position (left/center/right thirds).
    shared_visual_context = []
    width = rgb_image.shape[1]  # hoisted loop invariant
    for box in boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        label = class_names[int(box.cls[0])]
        conf = float(box.conf[0])
        depth_crop = depth_map_resized[y1:y2, x1:x2]
        # Guard against degenerate (empty) crops from zero-area boxes.
        avg_depth = float(depth_crop.mean()) if depth_crop.size > 0 else None
        x_center = (x1 + x2) / 2
        if x_center < width / 3:
            pos = "left"
        elif x_center > 2 * width / 3:
            pos = "right"
        else:
            pos = "center"
        shared_visual_context.append({
            "label": label,
            "confidence": conf,
            "avg_depth": avg_depth,
            "position": pos,
        })

    context_text = _build_context_description(shared_visual_context)

    # Transcribe the spoken question and prepend the visual context.
    transcription = whisper_pipe(audio_path)["text"]
    vqa_prompt = context_text + " " + transcription

    # Ask Gemma with both the image and the combined text prompt.
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": pil_image},
            {"type": "text", "text": vqa_prompt},
        ],
    }]
    gemma_output = gemma_pipe(text=messages, max_new_tokens=200)
    # The pipeline returns the full chat transcript; the model's reply is
    # the last message's content.
    answer = gemma_output[0]["generated_text"][-1]["content"]

    # Voice the answer to a WAV file next to the working directory.
    output_audio_path = "response.wav"
    tts.tts_to_file(text=answer, file_path=output_audio_path)

    return answer, output_audio_path