Spaces:
Runtime error
Runtime error
Aastha
commited on
Commit
Β·
002cef1
1
Parent(s):
aa73c8d
add configurable device support
Browse files
app.py
CHANGED
|
@@ -19,6 +19,7 @@ from super_gradients.training import models
|
|
| 19 |
|
| 20 |
class Kosmos2:
|
| 21 |
def __init__(self):
|
|
|
|
| 22 |
self.colors = [
|
| 23 |
(0, 255, 0),
|
| 24 |
(0, 0, 255),
|
|
@@ -43,7 +44,7 @@ class Kosmos2:
|
|
| 43 |
}
|
| 44 |
|
| 45 |
self.ckpt = "ydshieh/kosmos-2-patch14-224"
|
| 46 |
-
self.model = AutoModelForVision2Seq.from_pretrained(self.ckpt, trust_remote_code=True).to(
|
| 47 |
self.processor = AutoProcessor.from_pretrained(self.ckpt, trust_remote_code=True)
|
| 48 |
|
| 49 |
def is_overlapping(self, rect1, rect2):
|
|
@@ -191,11 +192,11 @@ class Kosmos2:
|
|
| 191 |
inputs = self.processor(text=text_input, images=image_input, return_tensors="pt")
|
| 192 |
|
| 193 |
generated_ids = self.model.generate(
|
| 194 |
-
pixel_values=inputs["pixel_values"].to(
|
| 195 |
-
input_ids=inputs["input_ids"][:, :-1].to(
|
| 196 |
-
attention_mask=inputs["attention_mask"][:, :-1].to(
|
| 197 |
img_features=None,
|
| 198 |
-
img_attn_mask=inputs["img_attn_mask"][:, :-1].to(
|
| 199 |
use_cache=True,
|
| 200 |
max_new_tokens=128,
|
| 201 |
)
|
|
|
|
| 19 |
|
| 20 |
class Kosmos2:
|
| 21 |
def __init__(self):
|
| 22 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 23 |
self.colors = [
|
| 24 |
(0, 255, 0),
|
| 25 |
(0, 0, 255),
|
|
|
|
| 44 |
}
|
| 45 |
|
| 46 |
self.ckpt = "ydshieh/kosmos-2-patch14-224"
|
| 47 |
+
self.model = AutoModelForVision2Seq.from_pretrained(self.ckpt, trust_remote_code=True).to(self.device)
|
| 48 |
self.processor = AutoProcessor.from_pretrained(self.ckpt, trust_remote_code=True)
|
| 49 |
|
| 50 |
def is_overlapping(self, rect1, rect2):
|
|
|
|
| 192 |
inputs = self.processor(text=text_input, images=image_input, return_tensors="pt")
|
| 193 |
|
| 194 |
generated_ids = self.model.generate(
|
| 195 |
+
pixel_values=inputs["pixel_values"].to(self.device),
|
| 196 |
+
input_ids=inputs["input_ids"][:, :-1].to(self.device),
|
| 197 |
+
attention_mask=inputs["attention_mask"][:, :-1].to(self.device),
|
| 198 |
img_features=None,
|
| 199 |
+
img_attn_mask=inputs["img_attn_mask"][:, :-1].to(self.device),
|
| 200 |
use_cache=True,
|
| 201 |
max_new_tokens=128,
|
| 202 |
)
|