Spaces:
Sleeping
Sleeping
| import os | |
| import requests | |
| import google.generativeai as genai | |
| from PIL import Image | |
| from io import BytesIO | |
| from typing import List, Union | |
| import logging | |
| from dotenv import load_dotenv | |
# Logging: configure the root logger at INFO level, then get this module's logger.
# NOTE(review): basicConfig() at import time configures logging for the whole
# process; consider moving it to the application entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# load_dotenv() must run before the key lookup below so that values coming from
# a .env file (python-dotenv) are visible in the environment.
load_dotenv()

# Gemini client for image questions; an unset variable passes api_key=None.
genai.configure(api_key=os.environ.get("GEMINI_API_KEY_IMAGE"))
def load_image(image_source: str) -> Image.Image:
    """Load an image from an HTTP(S) URL or a local file path, as RGB.

    Args:
        image_source: A URL starting with ``http://`` or ``https://``, or the
            path of an existing local file.

    Returns:
        A new ``PIL.Image.Image`` converted to RGB mode.

    Raises:
        RuntimeError: If the source is neither a URL nor an existing file, the
            download fails (network error, non-2xx status, or the 30-second
            connect/read timeout), or the data cannot be opened as an image.
            The original exception is chained as ``__cause__``.
    """
    try:
        if image_source.startswith(("http://", "https://")):
            logger.info(f"Loading image from URL: {image_source}")
            response = requests.get(image_source, timeout=30)
            response.raise_for_status()
            # convert() returns a separate image object, so the source image
            # can be closed as soon as the conversion is done.
            with Image.open(BytesIO(response.content)) as img:
                return img.convert("RGB")
        elif os.path.isfile(image_source):
            logger.info(f"Loading image from file: {image_source}")
            # Context manager closes the file handle deterministically; without
            # it, Pillow may keep the file open (e.g. for multi-frame formats)
            # until the source image is garbage-collected.
            with Image.open(image_source) as img:
                return img.convert("RGB")
        else:
            raise ValueError("Invalid image source: must be a valid URL or file path")
    except Exception as e:
        logger.error(f"Failed to load image from {image_source}: {e}")
        raise RuntimeError(f"Failed to load image: {e}") from e
def get_answer_for_image(image_source: str, questions: List[str], retries: int = 3) -> List[str]:
    """Ask several questions about one image using the Gemini vision model.

    All questions are sent together with the image in a single numbered
    prompt, and the reply is parsed back into one answer per question.

    Args:
        image_source: URL or local file path, passed to ``load_image``.
        questions: Questions to ask; answers come back in the same order.
        retries: Total number of ``generate_content`` attempts. A value below
            1 means no request is made and ``RuntimeError`` is raised.

    Returns:
        One answer string per question. An empty ``questions`` list returns
        ``[]`` immediately, without loading the image or calling the API.

    Raises:
        RuntimeError: If the last attempt raises (the error is chained as
            ``__cause__``) or no attempt yields the expected number of
            answers. Exceptions from ``load_image`` are logged and re-raised
            unchanged.
    """
    # Nothing to ask: skip the image download/open and the billable API call.
    if not questions:
        return []

    try:
        logger.info(f"Processing image with {len(questions)} questions")
        image = load_image(image_source)

        prompt = (
            "\n"
            "Answer the following questions about the image. Give the answers in the same order as the questions.\n"
            "Answers should be descriptive and detailed. Give one answer per line with numbering as \"1. 2. 3. ..\".\n"
            "Example answer format:\n"
            "1. Answer 1, Explanation\n"
            "2. Answer 2, Explanation\n"
            "3. Answer 3, Explanation\n"
            "Questions:\n"
        )
        prompt += "\n".join(f"{i+1}. {q}" for i, q in enumerate(questions))

        model = genai.GenerativeModel("gemini-1.5-flash")
        # Same settings for every attempt, so build the config once.
        generation_config = genai.types.GenerationConfig(
            temperature=0.4,
            max_output_tokens=2048,
        )

        for attempt in range(retries):
            try:
                logger.info(f"Attempt {attempt + 1} of {retries} to get response from Gemini")
                response = model.generate_content(
                    [prompt, image],
                    generation_config=generation_config,
                )
                # Any exception in this try (including from reading
                # response.text) counts as a failed attempt.
                raw_text = response.text.strip()
                logger.info(f"Received response from Gemini: {len(raw_text)} characters")

                answers = extract_ordered_answers(raw_text, len(questions))
                if len(answers) == len(questions):
                    logger.info(f"Successfully extracted {len(answers)} answers")
                    return answers
                logger.warning(f"Expected {len(questions)} answers, got {len(answers)}")
            except Exception as e:
                logger.error(f"Attempt {attempt + 1} failed: {e}")
                if attempt == retries - 1:
                    raise RuntimeError(f"Failed after {retries} attempts: {e}") from e

        # Reached when retries < 1 or no attempt produced a usable answer list.
        raise RuntimeError("Failed to get valid response from Gemini.")
    except Exception as e:
        logger.error(f"Error in get_answer_for_image: {e}")
        raise
def extract_ordered_answers(raw_text: str, expected_count: int) -> List[str]:
    """Parse the raw Gemini output into a clean list of answers.

    Lines that start with an item number, such as "1. Answer", "1) Answer",
    "1: Answer", "1 - Answer" or "1 Answer", are taken as answers in the
    order they appear, with the numbering stripped. Only if no numbered line
    is found at all is every non-empty line used verbatim instead.

    Args:
        raw_text: Text returned by the model.
        expected_count: Number of answers the caller asked for.

    Returns:
        For a non-negative ``expected_count``, exactly that many strings.
        Surplus answers are dropped and missing ones are filled with
        "Unable to extract answer from image".
    """
    import re

    # Either whitespace or one of ". ) : -" must follow the number, and that
    # punctuation may not be followed directly by a digit. A bare optional
    # separator turned "3D view..." into "D view...", "1.5 m tall" into
    # "5 m tall", "1: Answer" into ": Answer", and a lone "1." into ".".
    numbered_line = re.compile(r"^\s*(\d+)(?:\s*[.):\-](?!\d)|\s)\s*(.+)")

    logger.debug(f"Extracting {expected_count} answers from raw text")
    lines = raw_text.splitlines()

    answers = []
    for line in lines:
        match = numbered_line.match(line)
        if match:
            answer_text = match.group(2).strip()
            if answer_text:  # Only add non-empty answers
                answers.append(answer_text)

    # Fall back to plain lines only when the model ignored the numbering
    # entirely. Mixing raw lines in after some numbered items were found would
    # pull in preamble text and un-stripped numbers and misalign the answers.
    if not answers and expected_count > 0:
        logger.warning("Numbered extraction failed, using fallback method")
        answers = [line.strip() for line in lines if line.strip()]
    elif len(answers) < expected_count:
        logger.warning(
            f"Only {len(answers)} of {expected_count} numbered answers found; padding the rest"
        )

    # Return exactly the expected number of answers (list * negative == []).
    result = answers[:expected_count]
    result += ["Unable to extract answer from image"] * (expected_count - len(result))
    logger.info(f"Extracted {len(result)} answers")
    return result
def process_image_query(image_path: str, query: str) -> str:
    """Answer a single question about an image without ever raising.

    Args:
        image_path: URL or local path of the image.
        query: The question to ask.

    Returns:
        The first answer from ``get_answer_for_image``. Returns "Unable to
        process image query" when the answer list is empty, or
        "Error processing image: <details>" if anything raised (the error is
        logged).
    """
    try:
        results = get_answer_for_image(image_path, [query])
        if not results:
            return "Unable to process image query"
        return results[0]
    except Exception as exc:
        logger.error(f"Error processing image query: {exc}")
        return f"Error processing image: {str(exc)}"
def process_multiple_image_queries(image_path: str, queries: List[str]) -> List[str]:
    """Answer several questions about one image without ever raising.

    Args:
        image_path: URL or local path of the image.
        queries: Questions to ask, in order.

    Returns:
        The answers from ``get_answer_for_image``. On any exception the error
        is logged and every entry is the same
        "Error processing image: <details>" string, one per query.
    """
    try:
        answers = get_answer_for_image(image_path, queries)
    except Exception as exc:
        logger.error(f"Error processing multiple image queries: {exc}")
        failure_message = f"Error processing image: {str(exc)}"
        return [failure_message] * len(queries)
    return answers