image / app.py
Arfa-ilyas's picture
Update app.py
314c39a verified
import streamlit as st
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
import torch
from PIL import Image
import requests
import os
# Load the captioning model, its image processor and its tokenizer.
# All three come from the same pretrained checkpoint, so name it once.
_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"

model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT)
feature_extractor = ViTImageProcessor.from_pretrained(_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)

# Run on CUDA when torch reports it as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
# Define the image captioning function
def generate_caption(image):
    """Return a text caption for ``image``.

    Uses the module-level ``feature_extractor``, ``model``, ``tokenizer``
    and ``device``.

    Args:
        image: The image to caption, passed straight to the feature
            extractor's ``images`` argument.

    Returns:
        The first decoded sequence, with special tokens skipped and
        surrounding whitespace stripped.
    """
    # Turn the image into a PyTorch tensor batch and move it to the model's device.
    encoded = feature_extractor(images=image, return_tensors="pt")
    model_input = encoded.pixel_values.to(device)

    # Generate with 4 beams, capped at a maximum length of 16.
    generated_ids = model.generate(model_input, max_length=16, num_beams=4)

    # Only one image was passed in, so take the first decoded entry.
    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    first_caption = decoded[0]
    return first_caption.strip()
# Function to fetch product information from GROQ API
def fetch_product_info(caption):
    """Query the GROQ API for product details matching ``caption``.

    The API key is read from the ``GROQ_API_KEY`` environment variable,
    falling back to the legacy ``groq_api`` name. Failures are reported
    in the Streamlit UI rather than raised.

    Args:
        caption: Text sent to the API as the ``query`` parameter.

    Returns:
        A tuple ``(ingredients, usage, barcode)`` taken from the first
        entry of the response's ``products`` list, with ``'N/A'`` for any
        missing field. ``None`` when the key is missing, the request
        fails, the body is not valid JSON, or no product is returned.
    """
    # Placeholder endpoint: replace with your actual GROQ API URL.
    api_url = "GroqCloud"

    # Prefer the variable name the error message advertises, but keep
    # honouring the original 'groq_api' name so existing deployments work.
    api_key = os.getenv('GROQ_API_KEY') or os.getenv('groq_api')
    if not api_key:
        st.error("API key not found. Set the 'GROQ_API_KEY' environment variable.")
        return None

    # Set up the headers with the Bearer token
    headers = {"Authorization": f"Bearer {api_key}"}
    params = {"query": caption}

    try:
        # A timeout keeps the app from hanging forever on a stalled server.
        response = requests.get(api_url, headers=headers, params=params, timeout=10)
        response.raise_for_status()  # Raise an exception for HTTP errors
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers invalid JSON on requests versions whose JSON
        # decode error does not derive from RequestException.
        st.error(f"Error fetching data from GROQ API: {e}")
        return None

    # Guard against a non-object payload and a missing or empty product
    # list; indexing [0] blindly would raise instead of reporting it.
    products = data.get('products') if isinstance(data, dict) else None
    if not isinstance(products, list) or not products:
        st.write("No products found in the response.")
        return None

    product = products[0]  # Assuming the first product is relevant
    if not isinstance(product, dict):
        st.write("No products found in the response.")
        return None

    ingredients = product.get('ingredients', 'N/A')
    usage = product.get('usage', 'N/A')
    barcode = product.get('barcode', 'N/A')
    return ingredients, usage, barcode
# Streamlit UI: page header and short instructions.
st.title("Image Captioning and Product Information")
st.write("Upload an image to get a caption and related product information.")

# Upload image (jpg, png and jpeg files are accepted).
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

if uploaded_image:
    # Open the upload and convert it to RGB mode before showing it.
    rgb_image = Image.open(uploaded_image).convert("RGB")
    st.image(rgb_image, caption="Uploaded Image", use_column_width=True)

    # Generate the caption while a spinner is shown, then display it.
    with st.spinner("Generating caption..."):
        caption = generate_caption(rgb_image)
    st.write(f"**Caption:** {caption}")

    # Fetch and display additional information from GROQ API
    product_details = fetch_product_info(caption)
    if not product_details:
        st.write("No additional information found for this product.")
    else:
        ingredients, usage, barcode = product_details
        st.write("### Additional Information:")
        st.write(f"**Ingredients:** {ingredients}")
        st.write(f"**Usage:** {usage}")
        st.write(f"**Barcoding:** {barcode}")