cwangrun commited on
Commit
e8e1834
·
verified ·
1 Parent(s): 9acd189

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +29 -47
README.md CHANGED
@@ -1,47 +1,29 @@
1
- import torch
2
- from PIL import Image
3
- from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
4
-
5
- device = "cuda" if torch.cuda.is_available() else "cpu"
6
-
7
- # ===== 加载模型 =====
8
- # model = AutoModel.from_pretrained(
9
- # "StanfordAIMI/CheXficient",
10
- # trust_remote_code=True
11
- # ).to(device)
12
-
13
- model = AutoModel.from_pretrained(
14
- "/mnt/d/torch/CheXficient/hf_model",
15
- trust_remote_code=True
16
- ).to(device)
17
-
18
-
19
- # ===== 加载tokenizer =====
20
- tokenizer = AutoTokenizer.from_pretrained(
21
- "emilyalsentzer/Bio_ClinicalBERT"
22
- )
23
-
24
- # ===== 加载image processor =====
25
- image_processor = AutoImageProcessor.from_pretrained(
26
- "facebook/dinov2-base"
27
- )
28
-
29
- # ===== 准备数据 =====
30
- image = Image.open("xray.jpg").convert("RGB")
31
- text = ["pneumonia", "no acute cardiopulmonary abnormality"]
32
-
33
- image_inputs = image_processor(images=image, return_tensors="pt").to(device)
34
- text_inputs = tokenizer(text, padding=True, return_tensors="pt").to(device)
35
-
36
- # ===== 推理 =====
37
- with torch.no_grad():
38
- outputs = model(
39
- pixel_values=image_inputs["pixel_values"],
40
- input_ids=text_inputs["input_ids"],
41
- attention_mask=text_inputs["attention_mask"],
42
- )
43
-
44
- logits = outputs["logits_per_image"]
45
- probs = logits.softmax(dim=-1)
46
-
47
- print(probs)
 
1
+ import torch
2
+ from PIL import Image
3
+ from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
4
+
5
+ repo_id = "StanfordAIMI/CheXficient"
6
+ device = "cuda" if torch.cuda.is_available() else "cpu"
7
+
8
+ model = AutoModel.from_pretrained(
9
+ repo_id,
10
+ trust_remote_code=True
11
+ ).to(device)
12
+
13
+ tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
14
+ image_processor = AutoImageProcessor.from_pretrained(repo_id, trust_remote_code=True)
15
+
16
+ model.eval()
17
+
18
+ image = Image.open("./CXR/images/5AF3BB6C1BCC83C.png").convert("RGB")
19
+ text = ["Pneumonia", "no Pneumonia"]
20
+
21
+ image_inputs = image_processor(images=image, return_tensors="pt").to(device)
22
+ text_inputs = tokenizer(text, padding=True, return_tensors="pt").to(device)
23
+
24
+ with torch.no_grad():
25
+ outputs = model(
26
+ pixel_values=image_inputs["pixel_values"],
27
+ text_tokens=text_inputs,
28
+ )
29
+ print(outputs)