Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer | |
| import torch | |
| from PIL import Image | |
| import requests | |
| import os | |
# Load the captioning model together with its image processor and tokenizer.
# All three are fetched from the same pretrained checkpoint.
_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"

model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT)
feature_extractor = ViTImageProcessor.from_pretrained(_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
# Define the image captioning function
def generate_caption(image):
    """Return a text caption for ``image``.

    The image is preprocessed with the module-level ``feature_extractor``,
    moved to ``device``, passed to ``model.generate`` (``max_length=16``,
    ``num_beams=4``) and decoded with ``tokenizer``.

    Args:
        image: The image to caption, in whatever form
            ``feature_extractor(images=...)`` accepts.

    Returns:
        The first decoded sequence with surrounding whitespace stripped.
    """
    encoded = feature_extractor(images=image, return_tensors="pt")
    inputs = encoded.pixel_values.to(device)
    generated_ids = model.generate(inputs, max_length=16, num_beams=4)
    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    # Only the first decoded sequence is used as the caption.
    return decoded[0].strip()
# Function to fetch product information from GROQ API
def fetch_product_info(caption, timeout=10):
    """Query the product API with ``caption`` and return the first match.

    Args:
        caption: Text sent to the API as the ``query`` parameter.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        A ``(ingredients, usage, barcode)`` tuple taken from the first entry
        of the response's ``products`` list, with ``'N/A'`` for any missing
        field. ``None`` when the API key is not configured, the request or
        JSON decoding fails, or the response contains no usable product; in
        each of those cases a message is shown through Streamlit.
    """
    # Define your GROQ API endpoint
    # NOTE(review): this is still a placeholder, not a URL. Until it is
    # replaced, requests rejects it and the error branch below is taken.
    api_url = "GroqCloud"  # Replace with your actual GROQ API URL

    # The code historically read 'groq_api' while the error message tells the
    # user to set 'GROQ_API_KEY'; accept either so the message is accurate.
    api_key = os.getenv('groq_api') or os.getenv('GROQ_API_KEY')
    if not api_key:
        st.error("API key not found. Set the 'GROQ_API_KEY' environment variable.")
        return None

    # Set up the headers with the Bearer token
    headers = {"Authorization": f"Bearer {api_key}"}
    params = {"query": caption}

    try:
        # A timeout keeps the app from hanging on an unresponsive server.
        response = requests.get(
            api_url, headers=headers, params=params, timeout=timeout
        )
        response.raise_for_status()  # Raise an exception for HTTP errors
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on requests versions whose
        # JSON decode error is not a RequestException subclass.
        st.error(f"Error fetching data from GROQ API: {e}")
        return None

    # Validate the payload shape instead of indexing blindly: a missing or
    # empty 'products' list used to raise an uncaught IndexError.
    products = data.get('products') if isinstance(data, dict) else None
    first = products[0] if isinstance(products, list) and products else None
    if not isinstance(first, dict):
        st.write("No products found in the response.")
        return None

    # Assuming the first product is the relevant one.
    ingredients = first.get('ingredients', 'N/A')
    usage = first.get('usage', 'N/A')
    barcode = first.get('barcode', 'N/A')
    return ingredients, usage, barcode
# Streamlit UI: page header, image upload, caption, then product lookup.
st.title("Image Captioning and Product Information")
st.write("Upload an image to get a caption and related product information.")

# Upload image
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

if uploaded_image:
    # Convert to RGB before display and captioning.
    image = Image.open(uploaded_image).convert("RGB")
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Generate and display caption
    with st.spinner("Generating caption..."):
        caption = generate_caption(image)
    st.write(f"**Caption:** {caption}")

    # Fetch and display additional information from GROQ API
    info = fetch_product_info(caption)
    if not info:
        st.write("No additional information found for this product.")
    else:
        ingredients, usage, barcode = info
        st.write("### Additional Information:")
        st.write(f"**Ingredients:** {ingredients}")
        st.write(f"**Usage:** {usage}")
        st.write(f"**Barcoding:** {barcode}")