Ishgan committed on
Commit
457ec39
·
verified ·
1 Parent(s): 1cc09bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -15
app.py CHANGED
@@ -1,19 +1,35 @@
1
- import torch
2
- from transformers import LlavaProcessor, LlavaForConditionalGeneration
3
- import gradio as gr
4
  from PIL import Image
5
 
6
- # Load LLaVA model and processor
7
- model_id = "ICTNLP/llava-mini-llama-3.1-8b"
8
- processor = LlavaProcessor.from_pretrained(model_id)
9
- model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- # Function to generate captions
12
- def generate_caption(image, prompt="Describe this image."):
13
- inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda")
14
- output = model.generate(**inputs, max_new_tokens=50)
15
- return processor.batch_decode(output, skip_special_tokens=True)[0]
16
 
17
- # Gradio UI
18
- demo = gr.Interface(fn=generate_caption, inputs=[gr.Image(type="pil"), gr.Textbox()], outputs="text")
19
- demo.launch()
 
1
+ import requests
 
 
2
  from PIL import Image
3
 
4
+ import torch
5
+ from transformers import AutoProcessor, LlavaForConditionalGeneration
6
+
7
+ model_id = "llava-hf/llava-1.5-7b-hf"
8
+ model = LlavaForConditionalGeneration.from_pretrained(
9
+ model_id,
10
+ torch_dtype=torch.float16,
11
+ low_cpu_mem_usage=True,
12
+ ).to(0)
13
+
14
+ processor = AutoProcessor.from_pretrained(model_id)
15
+
16
+ # Define a chat history and use `apply_chat_template` to get correctly formatted prompt
17
+ # Each value in "content" has to be a list of dicts with types ("text", "image")
18
+ conversation = [
19
+ {
20
+
21
+ "role": "user",
22
+ "content": [
23
+ {"type": "text", "text": "What are these?"},
24
+ {"type": "image"},
25
+ ],
26
+ },
27
+ ]
28
+ prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
29
 
30
+ image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
31
+ raw_image = Image.open(requests.get(image_file, stream=True).raw)
32
+ inputs = processor(images=raw_image, text=prompt, return_tensors='pt').to(0, torch.float16)
 
 
33
 
34
+ output = model.generate(**inputs, max_new_tokens=200, do_sample=False)
35
+ print(processor.decode(output[0][2:], skip_special_tokens=True))