WaysAheadGlobal committed on
Commit 9aa78bc · verified · 1 Parent(s): d8f06bf

Update app.py

Files changed (1)
  1. app.py +10 -9
app.py CHANGED
@@ -4,7 +4,7 @@ import streamlit as st
 from PIL import Image
 import torch
 
-# Import TinyLLaVA modules (use local copy!)
+# ✅ Local TinyLLaVA from real LLaVA repo
 from tinyllava.model.builder import load_pretrained_model
 from tinyllava.utils import disable_torch_init
 from tinyllava.mm_utils import (
@@ -13,11 +13,13 @@ from tinyllava.mm_utils import (
     get_model_name_from_path
 )
 
-# Disable torch default init for speed
+# Disable torch default init for faster startup
 disable_torch_init()
 
-# Load TinyLLaVA 3.1B
+# Load TinyLLaVA 3.1B (best small version)
 MODEL_PATH = "bczhou/TinyLLaVA-3.1B"
+
+# Loads tokenizer, model, image processor, context length
 tokenizer, model, image_processor, context_len = load_pretrained_model(
     model_path=MODEL_PATH,
     model_base=None,
@@ -31,9 +33,8 @@ model.to(device)
 st.set_page_config(page_title="TinyLLaVA 3.1B (Streamlit)", layout="centered")
 st.title("🦙 TinyLLaVA 3.1B — Vision-Language Q&A")
 
-uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
-
-prompt = st.text_input("Ask a question about the image:")
+uploaded_file = st.file_uploader("📷 Upload an image", type=["jpg", "png", "jpeg"])
+prompt = st.text_input("💬 Ask a question about the image:")
 
 if uploaded_file is not None and prompt:
     image = Image.open(uploaded_file).convert("RGB")
@@ -42,12 +43,12 @@ if uploaded_file is not None and prompt:
     image_tensor = process_images([image], image_processor, model.config)
     image_tensor = image_tensor.to(device)
 
-    # Process prompt
+    # Build prompt with image tokens
    prompt_text = tokenizer_image_token(prompt, tokenizer, context_len)
    inputs = tokenizer([prompt_text])
    input_ids = torch.tensor(inputs.input_ids).unsqueeze(0).to(device)
 
-    # Run inference
+    # Generate
     with st.spinner("Generating answer..."):
         output_ids = model.generate(
             input_ids,
@@ -58,5 +59,5 @@ if uploaded_file is not None and prompt:
         )
     out_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
-    st.subheader("Answer:")
+    st.subheader("📝 Answer:")
     st.write(out_text)
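Note on the prompt-processing hunk above: in upstream LLaVA, tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None) takes the image-token index (default -200) as its third argument rather than the context length, and the multimodal generate() call receives the pixel tensor via images=. A minimal sketch of that conventional call sequence, assuming the local tinyllava copy mirrors upstream LLaVA's constants and signatures (the tinyllava.constants import is part of that assumption):

# Sketch only: assumes tinyllava mirrors upstream LLaVA's constants/signatures.
from tinyllava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN

# tokenizer_image_token splices the image-token index wherever the
# "<image>" placeholder appears, so the placeholder must be in the prompt.
full_prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt

input_ids = (
    tokenizer_image_token(full_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
    .unsqueeze(0)
    .to(device)
)

with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,  # pixel tensor goes here, not through input_ids
        do_sample=False,
        max_new_tokens=128,
    )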
 
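To sanity-check this revision locally, the app should run like any Streamlit script: install streamlit, torch, and pillow alongside the vendored tinyllava package, then launch with "streamlit run app.py"; load_pretrained_model fetches the bczhou/TinyLLaVA-3.1B weights from the Hub on first start.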