WaysAheadGlobal commited on
Commit
ef81d40
·
verified ·
1 Parent(s): 4265501

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -52
app.py CHANGED
@@ -1,62 +1,60 @@
1
  # app.py
2
 
3
  import gradio as gr
 
 
 
4
  import torch
5
- import cv2
6
  from PIL import Image
7
- from transformers import LlavaProcessor, LlavaForConditionalGeneration
8
 
9
- # Load LLaVA model (MiniGPT-4 style)
10
- model_id = "llava-hf/llava-1.5-7b-hf"
11
- processor = LlavaProcessor.from_pretrained(model_id)
12
- model = LlavaForConditionalGeneration.from_pretrained(model_id)
 
 
 
 
 
 
13
 
14
  device = torch.device("cpu")
15
  model.to(device)
16
 
17
- # Function: read webcam, yield frame + LLaVA caption every few seconds
18
- def webcam_llava():
19
- cap = cv2.VideoCapture(0)
20
- if not cap.isOpened():
21
- raise RuntimeError("Webcam could not be opened.")
22
-
23
- while True:
24
- ret, frame = cap.read()
25
- if not ret:
26
- break
27
-
28
- # Convert BGR to RGB PIL
29
- rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
30
- pil_image = Image.fromarray(rgb_frame)
31
-
32
- # --- Compose prompt for LLaVA ---
33
- prompt = "<image>\nUSER: Describe this scene in detail.\nASSISTANT:"
34
- inputs = processor(prompt, pil_image, return_tensors="pt").to(device)
35
-
36
- # Generate
37
- output = model.generate(**inputs, max_new_tokens=200)
38
- caption = processor.decode(output[0], skip_special_tokens=True)
39
-
40
- # Yield current frame + caption
41
- yield rgb_frame, caption
42
-
43
- # Wait before next frame (adjust as needed)
44
- cv2.waitKey(10000) # 10 seconds for CPU safety
45
-
46
- cap.release()
47
-
48
- # Gradio app
49
- with gr.Blocks() as demo:
50
- gr.Markdown("# 🎥 LLaVA MiniGPT-4 Webcam Captioning\n_(CPU, slow but descriptive)_")
51
-
52
- webcam_display = gr.Image(label="Live Webcam")
53
- description = gr.Textbox(label="LLaVA Caption")
54
-
55
- demo.load(
56
- fn=webcam_llava,
57
- inputs=None,
58
- outputs=[webcam_display, description],
59
- every=1
60
- )
61
-
62
- demo.launch()
 
1
  # app.py
2
 
3
  import gradio as gr
4
+ from tinyllava.model.builder import load_pretrained_model
5
+ from tinyllava.utils import disable_torch_init
6
+ from tinyllava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
7
  import torch
 
8
  from PIL import Image
 
9
 
10
+ # --- Disable unnecessary torch init ---
11
+ disable_torch_init()
12
+
13
+ # --- Load TinyLLaVA 3.1B ---
14
+ model_path = "bczhou/TinyLLaVA-3.1B" # official HF ID
15
+ tokenizer, model, image_processor, context_len = load_pretrained_model(
16
+ model_path=model_path,
17
+ model_base=None, # If you have a base model, point it here; else leave as is
18
+ model_name="TinyLLaVA-3.1B"
19
+ )
20
 
21
  device = torch.device("cpu")
22
  model.to(device)
23
 
24
+ # --- Gradio handler ---
25
+ def describe_image(image, prompt):
26
+ # TinyLLaVA wants PIL
27
+ image = Image.fromarray(image)
28
+ image_tensor = process_images([image], image_processor, model.config)
29
+ image_tensor = image_tensor.to(device)
30
+
31
+ prompt = tokenizer_image_token(prompt, tokenizer, context_len)
32
+
33
+ inputs = tokenizer([prompt])
34
+ input_ids = torch.tensor(inputs.input_ids).unsqueeze(0).to(device)
35
+
36
+ with torch.no_grad():
37
+ output_ids = model.generate(
38
+ input_ids,
39
+ images=image_tensor,
40
+ do_sample=True,
41
+ temperature=0.2,
42
+ max_new_tokens=200
43
+ )
44
+
45
+ out_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
46
+ return out_text
47
+
48
+ iface = gr.Interface(
49
+ fn=describe_image,
50
+ inputs=[
51
+ gr.Image(type="numpy", label="Image"),
52
+ gr.Textbox(label="Your question", placeholder="What's happening in this image?")
53
+ ],
54
+ outputs=gr.Textbox(label="TinyLLaVA Answer"),
55
+ title="🦙 TinyLLaVA-3.1B — Vision-Language Q&A",
56
+ description="A lightweight LLaVA variant that runs on CPU Spaces. Upload an image, ask a question."
57
+ )
58
+
59
+ if __name__ == "__main__":
60
+ iface.launch()