Spaces:
Sleeping
Sleeping
| """ | |
| DAM Model Classes for Demo | |
| Simplified versions of DAM inference for Hugging Face Space deployment | |
| """ | |
| import os | |
| import torch | |
| import time | |
| from PIL import Image | |
| from collections import defaultdict | |
| from typing import Dict, Tuple, Optional | |
| from transformers import AutoModel | |
| # Simplified utility functions | |
def resize_keep_aspect(img: Image.Image, max_size: int = 1024) -> Image.Image:
    """Downscale *img* so its longest side is at most *max_size* px.

    The aspect ratio is preserved; images already within the limit are
    returned unchanged (no copy is made).
    """
    width, height = img.size
    if max(width, height) <= max_size:
        return img
    # Scale the longest side to max_size; truncate the other side to int.
    new_size = (
        (max_size, int(height * max_size / width)) if width > height
        else (int(width * max_size / height), max_size)
    )
    return img.resize(new_size, Image.LANCZOS)
def create_full_image_mask(width: int, height: int) -> Image.Image:
    """Build a single-channel (mode "L") mask that selects the whole image."""
    full_white = 255  # 8-bit maximum: every pixel is "selected"
    return Image.new("L", (width, height), full_white)
def _window_starts(length: int, window_size: int, stride: int) -> list:
    """Return start offsets along one axis so windows cover [0, length).

    Offsets advance by *stride*; a final offset is appended when needed so
    the last window ends exactly at *length*. If the axis is shorter than
    *window_size*, a single window starting at 0 is used (it gets clamped
    to the image bounds by the caller).
    """
    if length <= window_size:
        return [0]
    starts = list(range(0, length - window_size + 1, stride))
    last = length - window_size
    if starts[-1] != last:
        starts.append(last)  # cover the trailing edge without duplicating
    return starts


def get_windows(width: int, height: int, window_size: int, stride: int):
    """Generate sliding window coordinates as (x0, y0, x1, y1) tuples.

    Fixes over the previous version:
    - images smaller than *window_size* now yield one clamped window
      (previously they yielded none, and the edge loops could emit
      windows with negative origins);
    - right/bottom edges are covered whenever (dim - window_size) is not
      a multiple of *stride* (the old `dim % stride` test both missed
      edges and produced duplicate windows).
    Row-major order (top-to-bottom, left-to-right) is preserved.
    """
    xs = _window_starts(width, window_size, stride)
    windows = []
    for y in _window_starts(height, window_size, stride):
        for x in xs:
            windows.append((x, y, min(x + window_size, width), min(y + window_size, height)))
    return windows
def aggregate_votes(votes: Dict[str, float]) -> str:
    """Return the answer carrying the largest accumulated vote weight.

    Ties resolve to the first-inserted answer; an empty vote map yields "".
    """
    if not votes:
        return ""
    return max(votes, key=votes.get)
class DAMOriginal:
    """Original DAM model: a single query over the full image (no windowing)."""

    def __init__(self, device: str = "auto"):
        """Load the DAM checkpoint onto the requested device.

        Args:
            device: "auto" picks CUDA when available, otherwise CPU; any
                other value is handed to ``torch.device`` unchanged.
        """
        if device == "auto":
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        print(f"Loading DAM model on {self.device}...")
        # The checkpoint ships its own modeling code, hence trust_remote_code.
        model = AutoModel.from_pretrained(
            "nvidia/DAM-3B-Self-Contained",
            trust_remote_code=True,
        )
        self.dam_model = model.to(self.device)
        self.dam = self.dam_model.init_dam(conv_mode="v1", prompt_mode="full+focal_crop")
        print("DAM Original model loaded successfully!")

    def predict(self, img: Image.Image, question: str, max_new_tokens: int = 100) -> Tuple[str, float]:
        """Answer *question* from the whole image.

        Returns:
            Tuple of (answer, inference_time in seconds). On failure the
            answer is the string "Error: <message>".
        """
        scaled = resize_keep_aspect(img, 1024)
        width, height = scaled.size
        # Mask selects the entire image.
        mask = create_full_image_mask(width, height)
        prompt = (
            "<image>\n"
            "Answer each question concisely in a single word or short phrase, "
            "without any lengthy descriptions or explanations.\n"
            "Rely only on information that is clearly visible in the provided image.\n"
            "If the answer cannot be determined from the image, respond with \"unanswerable\".\n"
            f"Question: {question}\nAnswer:"
        )
        # Near-zero temperature keeps generation effectively greedy.
        gen_kwargs = {
            "streaming": False,
            "temperature": 1e-7,
            "top_p": 0.5,
            "num_beams": 1,
            "max_new_tokens": max_new_tokens
        }
        started = time.time()
        try:
            raw = self.dam.get_description(scaled, mask, prompt, **gen_kwargs)
            elapsed = time.time() - started
            # get_description may return a string or an iterable of tokens.
            answer = raw.strip() if isinstance(raw, str) else "".join(raw).strip()
            return answer, elapsed
        except Exception as e:
            elapsed = time.time() - started
            print(f"Error in DAM Original prediction: {e}")
            return f"Error: {str(e)}", elapsed
class DAMSlidingWindow:
    """DAM model with a sliding-window voting approach.

    Queries the model on the full image and on overlapping square crops,
    then returns the answer with the highest accumulated vote weight
    (full image counts 1.0, each crop counts its area fraction).
    """

    def __init__(self, device: str = "auto", window_size: int = 512, stride: int = 256):
        """Load the DAM checkpoint and store windowing parameters.

        Args:
            device: "auto" picks CUDA when available, otherwise CPU.
            window_size: side length (px) of each sliding-window crop.
            stride: step (px) between consecutive window origins.
        """
        if device == "auto":
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        self.window_size = window_size
        self.stride = stride
        print(f"Loading DAM model on {self.device}...")
        # The checkpoint ships its own modeling code, hence trust_remote_code.
        self.dam_model = AutoModel.from_pretrained(
            "nvidia/DAM-3B-Self-Contained",
            trust_remote_code=True,
        ).to(self.device)
        self.dam = self.dam_model.init_dam(conv_mode="v1", prompt_mode="full+focal_crop")
        print(f"DAM Sliding Window model loaded successfully! (window_size={window_size}, stride={stride})")

    def predict(self, img: Image.Image, question: str, max_new_tokens: int = 100,
                unanswerable_weight: float = 1.0) -> Tuple[str, float, Dict]:
        """
        Generate prediction using sliding window approach with voting.

        Args:
            img: input image (resized so its longest side is <= 1024).
            question: question to answer about the image.
            max_new_tokens: generation budget per model call.
            unanswerable_weight: multiplier on the vote weight of
                "unanswerable" answers (< 1.0 de-emphasizes them).

        Returns:
            Tuple of (answer, inference_time, voting_details). On failure
            the answer is "Error: <message>" and details is {"error": ...}.
        """
        img = resize_keep_aspect(img, 1024)
        W, H = img.size
        prompt = (
            "<image>\n"
            "Answer each question concisely in a single word or short phrase, "
            "without any lengthy descriptions or explanations.\n"
            "Rely only on information that is clearly visible in the provided image.\n"
            "If the answer cannot be determined from the image, respond with \"unanswerable\".\n"
            f"Question: {question}\nAnswer:"
        )
        # Near-zero temperature keeps generation effectively greedy.
        params = {
            "streaming": False,
            "temperature": 1e-7,
            "top_p": 0.5,
            "num_beams": 1,
            "max_new_tokens": max_new_tokens
        }
        start_time = time.time()
        votes = defaultdict(float)
        voting_details = {"full_image": None, "windows": []}
        try:
            # Full-image vote (weight 1.0).
            mask_full = create_full_image_mask(W, H)
            ans_full = self.dam.get_description(img, mask_full, prompt, **params)
            if isinstance(ans_full, str):
                ans_full = ans_full.strip()
            else:
                ans_full = "".join(ans_full).strip()
            if ans_full:
                weight = 1.0
                if ans_full.lower() == "unanswerable":
                    weight *= unanswerable_weight
                votes[ans_full] += weight
                voting_details["full_image"] = {"answer": ans_full, "weight": weight}
            # Sliding-window votes, each weighted by its area fraction.
            windows = get_windows(W, H, self.window_size, self.stride)
            for i, (x0, y0, x1, y1) in enumerate(windows):
                crop = img.crop((x0, y0, x1, y1))
                mask_crop = Image.new("L", (x1 - x0, y1 - y0), 255)
                ans = self.dam.get_description(crop, mask_crop, prompt, **params)
                if isinstance(ans, str):
                    ans = ans.strip()
                else:
                    ans = "".join(ans).strip()
                if ans:
                    weight = ((x1 - x0) * (y1 - y0)) / (W * H)
                    if ans.lower() == "unanswerable":
                        weight *= unanswerable_weight
                    votes[ans] += weight
                    voting_details["windows"].append({
                        "window_id": i,
                        "coords": (x0, y0, x1, y1),
                        "answer": ans,
                        "weight": weight
                    })
            prediction = aggregate_votes(votes)
            if not prediction:
                # BUGFIX: the old `'ans_full' in locals()` guard was always True
                # here (any raise lands in the except clause), so an empty
                # full-image answer leaked through as "". Fall back to
                # "No answer" when ans_full is empty too.
                prediction = ans_full or "No answer"
            inference_time = time.time() - start_time
            # Attach the vote summary for UI display / debugging.
            voting_details["vote_summary"] = dict(votes)
            voting_details["final_answer"] = prediction
            voting_details["total_windows"] = len(windows)
            return prediction, inference_time, voting_details
        except Exception as e:
            inference_time = time.time() - start_time
            print(f"Error in DAM Sliding Window prediction: {e}")
            return f"Error: {str(e)}", inference_time, {"error": str(e)}
# Global model instances (lazy loading)
# Populated on first call to get_dam_original() / get_dam_sliding();
# subsequent calls reuse the cached instance.
_dam_original: Optional[DAMOriginal] = None
_dam_sliding: Optional[DAMSlidingWindow] = None
def get_dam_original(device: str = "auto"):
    """Return the process-wide DAMOriginal instance, creating it on first use."""
    global _dam_original
    if _dam_original is not None:
        return _dam_original
    _dam_original = DAMOriginal(device)
    return _dam_original
def get_dam_sliding(device: str = "auto", window_size: int = 512, stride: int = 256):
    """Return the process-wide DAMSlidingWindow instance, creating it on first use.

    NOTE: window_size and stride only take effect on the first call; later
    calls return the cached instance with its original parameters.
    """
    global _dam_sliding
    if _dam_sliding is not None:
        return _dam_sliding
    _dam_sliding = DAMSlidingWindow(device, window_size, stride)
    return _dam_sliding