ravi-vc commited on
Commit
502a6ed
Β·
verified Β·
1 Parent(s): 39daee9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -0
app.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoProcessor, AutoModelForVision2Seq
3
+ from PIL import Image
4
+ import torch
5
+
6
+ # Load LLaVA (Apache-2.0 license, safe to use)
7
+ model_id = "llava-hf/llava-1.5-7b-hf"
8
+
9
+ processor = AutoProcessor.from_pretrained(model_id)
10
+ model = AutoModelForVision2Seq.from_pretrained(
11
+ model_id,
12
+ torch_dtype=torch.float16,
13
+ low_cpu_mem_usage=True,
14
+ ).to("cuda")
15
+
16
+ def analyze_chart(image, question="Describe this chart"):
17
+ # Preprocess
18
+ inputs = processor(images=image, text=question, return_tensors="pt").to("cuda", torch.float16)
19
+
20
+ # Generate
21
+ output_ids = model.generate(**inputs, max_new_tokens=200)
22
+ response = processor.decode(output_ids[0], skip_special_tokens=True)
23
+ return response
24
+
25
+ demo = gr.Interface(
26
+ fn=analyze_chart,
27
+ inputs=[gr.Image(type="pil"), gr.Textbox(value="Describe this chart")],
28
+ outputs="text",
29
+ title="Chart Analyzer (LLaVA)",
30
+ description="Upload a chart and ask questions like 'What does this chart show?' or 'Which month has highest sales?'."
31
+ )
32
+
33
+ if __name__ == "__main__":
34
+ demo.launch()