malusama committed on
Commit
a44eea8
·
verified ·
1 Parent(s): 1fde89f

Load tokenizer and image processor directly from snapshot

Browse files
Files changed (1) hide show
  1. app.py +30 -8
app.py CHANGED
@@ -1,11 +1,13 @@
1
  from functools import lru_cache
 
2
  import json
3
  import os
 
4
 
5
  import torch
6
  from huggingface_hub import snapshot_download
7
  from PIL import Image
8
- from transformers import AutoModel, AutoProcessor
9
 
10
 
11
  os.environ["HF_ENDPOINT"] = "https://huggingface.co"
@@ -22,17 +24,22 @@ def load_components():
22
  repo_id=MODEL_ID,
23
  revision=MODEL_REVISION,
24
  )
 
 
 
25
  model = AutoModel.from_pretrained(
26
  model_dir,
27
  trust_remote_code=True,
28
  )
29
- processor = AutoProcessor.from_pretrained(
30
- model_dir,
31
- trust_remote_code=True,
32
  )
 
 
 
33
  model.to(DEVICE)
34
  model.eval()
35
- return model, processor
36
 
37
 
38
  def parse_labels(text: str):
@@ -52,10 +59,25 @@ def run_demo(image: Image.Image, candidate_text: str):
52
  if not labels:
53
  raise ValueError("Please enter at least one label.")
54
 
55
- model, processor = load_components()
56
  with torch.no_grad():
57
- text_inputs = processor(text=labels, return_tensors="pt")
58
- image_inputs = processor(images=image.convert("RGB"), return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  text_outputs = model(**text_inputs)
61
  image_outputs = model(**image_inputs)
 
1
  from functools import lru_cache
2
+ import importlib
3
  import json
4
  import os
5
+ import sys
6
 
7
  import torch
8
  from huggingface_hub import snapshot_download
9
  from PIL import Image
10
+ from transformers import AutoModel
11
 
12
 
13
  os.environ["HF_ENDPOINT"] = "https://huggingface.co"
 
24
  repo_id=MODEL_ID,
25
  revision=MODEL_REVISION,
26
  )
27
+ if model_dir not in sys.path:
28
+ sys.path.insert(0, model_dir)
29
+
30
  model = AutoModel.from_pretrained(
31
  model_dir,
32
  trust_remote_code=True,
33
  )
34
+ tokenizer = importlib.import_module("tokenization_glm").GLMChineseTokenizer(
35
+ vocab_file=os.path.join(model_dir, "sp.model")
 
36
  )
37
+ image_processor = importlib.import_module(
38
+ "image_processing_m2_encoder"
39
+ ).M2EncoderImageProcessor.from_pretrained(model_dir)
40
  model.to(DEVICE)
41
  model.eval()
42
+ return model, tokenizer, image_processor
43
 
44
 
45
  def parse_labels(text: str):
 
59
  if not labels:
60
  raise ValueError("Please enter at least one label.")
61
 
62
+ model, tokenizer, image_processor = load_components()
63
  with torch.no_grad():
64
+ text_inputs = tokenizer(
65
+ labels,
66
+ padding="max_length",
67
+ truncation=True,
68
+ max_length=52,
69
+ return_special_tokens_mask=True,
70
+ return_tensors="pt",
71
+ )
72
+ image_inputs = image_processor(image.convert("RGB"), return_tensors="pt")
73
+ text_inputs = {
74
+ key: value.to(DEVICE) if hasattr(value, "to") else value
75
+ for key, value in text_inputs.items()
76
+ }
77
+ image_inputs = {
78
+ key: value.to(DEVICE) if hasattr(value, "to") else value
79
+ for key, value in image_inputs.items()
80
+ }
81
 
82
  text_outputs = model(**text_inputs)
83
  image_outputs = model(**image_inputs)