arjunanand13 committed on
Commit
d654351
·
verified ·
1 Parent(s): d6c8e75

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  from PIL import Image
3
- from transformers import AutoProcessor, AutoModelForCausalLM
4
  import gradio as gr
5
  import json
6
  import traceback
@@ -10,7 +10,7 @@ model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
10
  token = os.getenv("HUGGINGFACE_TOKEN").strip()
11
 
12
  processor = AutoProcessor.from_pretrained(model_name, token=token)
13
- model = AutoModelForCausalLM.from_pretrained(
14
  model_name,
15
  quantization_config={"load_in_4bit": True},
16
  token=token
@@ -36,15 +36,15 @@ def analyze_image(image, prompt):
36
  return_tensors="pt"
37
  ).to(model.device)
38
 
39
- # Separate inputs for generate method
40
- generate_inputs = {
41
- k: v for k, v in inputs.items()
42
- if k not in ['pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask']
43
- }
44
 
45
  with torch.no_grad():
46
  output = model.generate(**generate_inputs, max_new_tokens=100)
47
-
48
  result = processor.decode(output[0], skip_special_tokens=True)
49
 
50
  try:
 
1
  import torch
2
  from PIL import Image
3
+ from transformers import AutoProcessor, AutoModelForPreTraining
4
  import gradio as gr
5
  import json
6
  import traceback
 
10
  token = os.getenv("HUGGINGFACE_TOKEN").strip()
11
 
12
  processor = AutoProcessor.from_pretrained(model_name, token=token)
13
+ model = AutoModelForPreTraining.from_pretrained(
14
  model_name,
15
  quantization_config={"load_in_4bit": True},
16
  token=token
 
36
  return_tensors="pt"
37
  ).to(model.device)
38
 
39
+ # # Separate inputs for generate method
40
+ # generate_inputs = {
41
+ # k: v for k, v in inputs.items()
42
+ # if k not in ['pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask']
43
+ # }
44
 
45
  with torch.no_grad():
46
  output = model.generate(**generate_inputs, max_new_tokens=100)
47
+ print(processor.decode(output[0]))
48
  result = processor.decode(output[0], skip_special_tokens=True)
49
 
50
  try: