Ishgan committed on
Commit
0c077d2
Β·
verified Β·
1 Parent(s): 43b8a22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -18
app.py CHANGED
@@ -1,41 +1,48 @@
1
  import requests
2
  from PIL import Image
3
-
4
  import torch
5
-
6
  from transformers import AutoProcessor, LlavaForConditionalGeneration
7
 
 
 
8
 
 
 
9
 
10
- model_id = "llava-hf/llava-1.5-7b-hf"
11
  model = LlavaForConditionalGeneration.from_pretrained(
12
  model_id,
13
  torch_dtype=torch.float16,
14
- low_cpu_mem_usage=True,
15
- ).to(0)
16
-
17
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
- model = model.to(device)
19
 
 
20
  processor = AutoProcessor.from_pretrained(model_id)
21
 
22
- # Define a chat history and use `apply_chat_template` to get correctly formatted prompt
23
- # Each value in "content" has to be a list of dicts with types ("text", "image")
24
  conversation = [
25
  {
26
-
27
- "role": "user",
28
- "content": [
29
- {"type": "text", "text": "What are these?"},
30
- {"type": "image"},
31
  ],
32
  },
33
  ]
 
 
34
  prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
35
 
36
- image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
37
- raw_image = Image.open(requests.get(image_file, stream=True).raw)
38
- inputs = processor(images=raw_image, text=prompt, return_tensors='pt').to(0, torch.float16)
39
 
 
 
 
 
 
40
  output = model.generate(**inputs, max_new_tokens=200, do_sample=False)
 
 
41
  print(processor.decode(output[0][2:], skip_special_tokens=True))
 
1
  import requests
2
  from PIL import Image
 
3
  import torch
 
4
  from transformers import AutoProcessor, LlavaForConditionalGeneration
5
 
6
+ # Model ID
7
+ model_id = "llava-hf/llava-1.5-7b-hf"
8
 
9
+ # Set device
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
+ # Load model onto the correct device
13
  model = LlavaForConditionalGeneration.from_pretrained(
14
  model_id,
15
  torch_dtype=torch.float16,
16
+ low_cpu_mem_usage=True
17
+ ).to(device)
 
 
 
18
 
19
+ # Load processor
20
  processor = AutoProcessor.from_pretrained(model_id)
21
 
22
+ # Define conversation
 
23
  conversation = [
24
  {
25
+ "role": "user",
26
+ "content": [
27
+ {"type": "text", "text": "What are these?"},
28
+ {"type": "image"},
 
29
  ],
30
  },
31
  ]
32
+
33
+ # Apply chat template
34
  prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
35
 
36
+ # Load image
37
+ image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
38
+ raw_image = Image.open(requests.get(image_url, stream=True).raw)
39
 
40
+ # Preprocess inputs
41
+ inputs = processor(images=raw_image, text=prompt, return_tensors='pt')
42
+ inputs = {k: v.to(device, torch.float16) for k, v in inputs.items()}
43
+
44
+ # Generate output
45
  output = model.generate(**inputs, max_new_tokens=200, do_sample=False)
46
+
47
+ # Decode and print result
48
  print(processor.decode(output[0][2:], skip_special_tokens=True))