786avinash committed on
Commit
68d7ce5
·
verified ·
1 Parent(s): b488bd8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -43
app.py CHANGED
@@ -3,73 +3,79 @@ from transformers import BlipForQuestionAnswering, AutoProcessor
3
  from PIL import Image
4
  import gradio as gr
5
 
6
- # Load the BLIP model and processor
7
  model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
8
  processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
9
 
10
- # Define your Groq API key and endpoint
11
- groq_api_key = "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" # Replace with your Groq API key
12
- groq_api_url = "https://api.groq.com/openai/v1/chat/completions" # Replace with the appropriate Groq endpoint
13
 
14
- def qna(image, question, context):
15
  try:
16
- # Step 1: Get initial short answer from BLIP
17
  inputs = processor(image, question, return_tensors="pt")
18
  out = model.generate(**inputs)
19
  short_answer = processor.decode(out[0], skip_special_tokens=True)
 
 
 
 
20
 
21
- # Step 2: Construct prompt for Groq API
22
- full_prompt = f"{context}\nUser: {question}\nBLIP: {short_answer}\nAssistant:"
 
23
 
24
- # Step 3: Send prompt to Groq API for a detailed answer
25
  headers = {
26
  "Authorization": f"Bearer {groq_api_key}",
27
  "Content-Type": "application/json"
28
  }
29
 
30
  data = {
31
- "model": "llama3-8b-8192", # Specify the model to use
32
  "messages": [{"role": "user", "content": full_prompt}]
33
  }
34
 
35
  response = requests.post(groq_api_url, headers=headers, json=data)
36
-
37
- # Check if the response is successful
38
  if response.status_code == 200:
39
- detailed_answer = response.json().get('choices', [])[0].get('message', {}).get('content', '').strip()
40
- # Update the context with the latest question and answer
41
- updated_context = f"{context}\nUser: {question}\nAssistant: {detailed_answer}"
42
- return updated_context, updated_context # Return updated context for display
43
  else:
44
- return f"Error {response.status_code}: {response.text}", context
45
-
 
 
46
  except Exception as e:
47
- return f"An error occurred: {str(e)}", context
 
 
48
 
49
- # Create Gradio interface with context management
50
- def chatbot_interface(image, question, context=""):
51
- # Initialize context if image is uploaded
52
- if context == "" and image is not None:
53
- context = "" # Reset context when the image is first uploaded
54
 
55
- # Get the answer from the model
56
- answer, updated_context = qna(image, question, context)
57
-
58
- # Return the updated context for display
59
- return updated_context
60
-
61
- # Define the Gradio interface
62
- interf = gr.Interface(
63
- fn=chatbot_interface,
64
- inputs=[
65
- gr.Image(type="pil", label="Upload Image"),
66
- gr.Textbox(label="Ask a question")
67
- ],
68
- outputs="text", # Output the full conversation context
69
- title="Interactive Image Chatbot",
70
- description="Upload an image and have a conversation about it. Ask multiple questions about the image."
71
- )
 
 
 
 
 
 
 
72
 
73
- # Launch the interface
74
  if __name__ == "__main__":
75
- interf.launch()
 
# --- Imports and model/API configuration --------------------------------
import os

import requests  # FIX: used by qna() below but was never imported in the original file
import gradio as gr
from PIL import Image
from transformers import BlipForQuestionAnswering, AutoProcessor

# BLIP visual-question-answering model: yields a short answer for an
# (image, question) pair; the Groq LLM below expands it into a full reply.
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")

# SECURITY: the original file hard-coded a Groq API key here. Any key
# committed to a public repo must be treated as compromised — revoke it and
# supply a fresh one through the GROQ_API_KEY environment variable instead.
groq_api_key = os.environ.get("GROQ_API_KEY", "")
groq_api_url = "https://api.groq.com/openai/v1/chat/completions"
def qna(image, question, history):
    """Answer *question* about *image*, appending the new turn to *history*.

    Pipeline:
      1. The local BLIP VQA model produces a short answer for
         (image, question).
      2. That short answer, plus all prior Q/A turns, is sent to the Groq
         chat-completions API, which returns a detailed answer.

    Args:
        image: PIL image uploaded by the user.
        question: The user's current question (str).
        history: List of (question, answer) tuples from earlier turns;
            mutated in place with the new pair.

    Returns:
        ``(history, history)`` — the same list twice, so Gradio can route
        one copy to the Chatbot display and one to the State component.
    """
    try:
        # Step 1: short answer from the local BLIP VQA model.
        inputs = processor(image, question, return_tensors="pt")
        out = model.generate(**inputs)
        short_answer = processor.decode(out[0], skip_special_tokens=True)

        # Step 2: build a prompt that carries the conversation so far.
        context = "\n".join([f"Q: {q}\nA: {a}" for q, a in history])
        full_prompt = f"""Context of previous conversation:
{context}

Current Image Description: {short_answer}
Question: {question}
Please provide a detailed answer based on the image and previous context."""

        headers = {
            "Authorization": f"Bearer {groq_api_key}",
            "Content-Type": "application/json"
        }

        data = {
            "model": "llama3-8b-8192",
            "messages": [{"role": "user", "content": full_prompt}]
        }

        # FIX: requests has no default timeout; without one a stalled Groq
        # endpoint would hang this Gradio callback forever.
        response = requests.post(groq_api_url, headers=headers, json=data,
                                 timeout=60)

        if response.status_code == 200:
            detailed_answer = response.json()['choices'][0]['message']['content'].strip()
            history.append((question, detailed_answer))
            return history, history
        else:
            # Surface HTTP failures inside the chat instead of crashing.
            error_msg = f"Error {response.status_code}: {response.text}"
            history.append((question, error_msg))
            return history, history

    except Exception as e:
        # Catch-all at the callback boundary: network errors, unexpected
        # JSON shape, model failures — all shown to the user in the chat.
        error_msg = f"An error occurred: {str(e)}"
        history.append((question, error_msg))
        return history, history
52
def clear_history():
    """Reset the conversation.

    Returns a fresh empty list for the Chatbot display and another for the
    stored (question, answer) state, wiping both.
    """
    fresh_chat = []
    fresh_state = []
    return fresh_chat, fresh_state
# --- Gradio UI ----------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Interactive Image Chatbot")

    with gr.Row():
        image_input = gr.Image(type="pil")

    with gr.Row():
        with gr.Column():
            chatbot = gr.Chatbot()
            question = gr.Textbox(label="Ask a question about the image")
            clear = gr.Button("Clear Conversation")

    # Hidden per-session list of (question, answer) pairs.
    state = gr.State([])

    # Enter in the textbox runs the full VQA + Groq pipeline.
    question.submit(fn=qna, inputs=[image_input, question, state], outputs=[chatbot, state])

    # The button wipes both the visible chat and the stored state.
    clear.click(fn=clear_history, outputs=[chatbot, state])
# Start the Gradio server only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()