"""Streamlit app: caption an uploaded image and look up product information.

The caption is produced by the ``nlpconnect/vit-gpt2-image-captioning``
vision-encoder-decoder model and is then sent as the ``query`` parameter of
an HTTP GET request to a product-information endpoint.
"""

import os

import requests
import streamlit as st
import torch
from PIL import Image
from transformers import (
    AutoTokenizer,
    ViTImageProcessor,
    VisionEncoderDecoderModel,
)

CAPTION_MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"

# NOTE(review): "GroqCloud" is a placeholder, not a URL. Until it is replaced,
# every lookup fails inside `requests` and is reported through `st.error`.
PRODUCT_API_URL = "GroqCloud"  # Replace with your actual GROQ API URL

# Upper bound (seconds) for the product lookup so a stalled server cannot
# freeze the app.
REQUEST_TIMEOUT_SECONDS = 10


@st.cache_resource
def _load_captioning_components():
    """Load the captioning model, image processor, tokenizer and device once.

    Streamlit re-executes the script on every interaction; caching the
    result avoids re-instantiating the model on each rerun.

    Returns:
        A ``(model, feature_extractor, tokenizer, device)`` tuple, with the
        model already moved to ``device``.
    """
    loaded_model = VisionEncoderDecoderModel.from_pretrained(CAPTION_MODEL_NAME)
    loaded_extractor = ViTImageProcessor.from_pretrained(CAPTION_MODEL_NAME)
    loaded_tokenizer = AutoTokenizer.from_pretrained(CAPTION_MODEL_NAME)
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loaded_model.to(target_device)
    return loaded_model, loaded_extractor, loaded_tokenizer, target_device


# Module-level names are kept so the rest of the script can use them directly.
model, feature_extractor, tokenizer, device = _load_captioning_components()


def generate_caption(image):
    """Return a caption for ``image``.

    Args:
        image: The image to caption; the UI below passes an RGB
            ``PIL.Image.Image``.

    Returns:
        The decoded caption with surrounding whitespace stripped. Generation
        uses beam search (4 beams) and is capped at 16 tokens.
    """
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    output_ids = model.generate(pixel_values, max_length=16, num_beams=4)
    caption = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
    return caption.strip()


def _first_product(data):
    """Return the first entry of ``data["products"]`` if it is a dict, else None."""
    if not isinstance(data, dict):
        return None
    products = data.get("products")
    if not isinstance(products, list) or not products:
        return None
    first = products[0]
    return first if isinstance(first, dict) else None


def fetch_product_info(caption):
    """Look up product information for ``caption`` via the product API.

    Args:
        caption: Text sent as the ``query`` parameter of the GET request.

    Returns:
        An ``(ingredients, usage, barcode)`` tuple taken from the first
        product in the response (each field defaults to ``'N/A'``), or
        ``None`` when the API key is missing, the request fails, the body
        is not valid JSON, or the response contains no usable product.
        Failures are reported in the Streamlit UI rather than raised.
    """
    # The user-facing message names GROQ_API_KEY, so honour that variable
    # first; 'groq_api' is kept as a fallback for existing deployments.
    api_key = os.getenv("GROQ_API_KEY") or os.getenv("groq_api")
    if not api_key:
        st.error("API key not found. Set the 'GROQ_API_KEY' environment variable.")
        return None

    headers = {"Authorization": f"Bearer {api_key}"}
    params = {"query": caption}

    try:
        response = requests.get(
            PRODUCT_API_URL,
            headers=headers,
            params=params,
            timeout=REQUEST_TIMEOUT_SECONDS,
        )
        response.raise_for_status()  # Raise an exception for HTTP errors
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers JSON decoding failures on `requests` versions
        # whose decode error is not a RequestException subclass.
        st.error(f"Error fetching data from GROQ API: {e}")
        return None

    # Assumes the first product is the relevant one.
    product = _first_product(data)
    if product is None:
        st.write("No products found in the response.")
        return None

    ingredients = product.get("ingredients", "N/A")
    usage = product.get("usage", "N/A")
    barcode = product.get("barcode", "N/A")
    return ingredients, usage, barcode


# Streamlit UI
st.title("Image Captioning and Product Information")
st.write("Upload an image to get a caption and related product information.")

# Upload image
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

if uploaded_image:
    try:
        image = Image.open(uploaded_image).convert("RGB")
    except OSError as exc:
        # PIL raises OSError subclasses for unreadable or corrupt image data.
        st.error(f"Could not read the uploaded image: {exc}")
        st.stop()

    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Generate and display caption
    with st.spinner("Generating caption..."):
        caption = generate_caption(image)
    st.write(f"**Caption:** {caption}")

    # Fetch and display additional information from GROQ API
    info = fetch_product_info(caption)
    if info:
        ingredients, usage, barcode = info
        st.write("### Additional Information:")
        st.write(f"**Ingredients:** {ingredients}")
        st.write(f"**Usage:** {usage}")
        st.write(f"**Barcoding:** {barcode}")
    else:
        st.write("No additional information found for this product.")