| | import os |
| | import torch |
| | import io |
| | from PIL import Image |
| | import matplotlib.pyplot as plt |
| | from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPImageProcessor |
| | import logging |
| | import time |
| |
|
| | |
# Root logging: INFO and above, each record prefixed with a timestamp, the
# logger name and the level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Shared model components. They start out unset and are populated by
# load_model(); describe_image() triggers that load on first use.
model = processor = tokenizer = None
| |
|
def describe_image(image_path=None, image_data=None, show_image=False):
    """
    Load an image and generate descriptions of it with a Qwen2-VL model.

    The model is loaded lazily through load_model() on the first call.

    Args:
        image_path (str, optional): Path to the image file. Used in preference
            to ``image_data`` when both are given.
        image_data (bytes, optional): Raw encoded image bytes (e.g. an upload).
        show_image (bool): Whether to display the image with matplotlib first.

    Returns:
        dict: On success ``{"success": True, "basic_description": str,
        "detailed_description": str, "technical_analysis": str}``.
        On failure ``{"error": str}``; errors are reported in the returned
        dict rather than raised.
    """
    # Validate the arguments before paying for a multi-gigabyte model load.
    if image_path is None and image_data is None:
        return {"error": "No image provided"}

    if model is None or processor is None or tokenizer is None:
        # load_model() logs and returns False on failure instead of raising.
        # Without this check the failure would surface later as a confusing
        # "'NoneType' object is not callable" error.
        if not load_model():
            return {"error": "Model could not be loaded; see log for details"}

    try:
        if image_path is not None:
            if not os.path.exists(image_path):
                return {"error": f"Image not found at {image_path}"}
            logger.info("Processing image from path: %s", image_path)
            image = _open_rgb(image_path)
        else:
            logger.info("Processing image from uploaded data")
            image = _open_rgb(io.BytesIO(image_data))

        if show_image:
            title = os.path.basename(image_path) if image_path else None
            _show_image(image, title)

        logger.info("Generating descriptions...")

        # (result key, prompt, max_new_tokens) for each requested description.
        tasks = (
            ("basic_description", "Describe this image briefly.", 150),
            (
                "detailed_description",
                "Analyze this image in detail. Describe the main elements, "
                "any text visible, the colors, and the overall composition.",
                300,
            ),
            (
                "technical_analysis",
                "What can you tell me about the technical aspects of this image?",
                200,
            ),
        )
        results = {"success": True}
        for key, prompt, max_new_tokens in tasks:
            results[key] = _generate(image, prompt, max_new_tokens)
        return results

    except Exception as e:
        logger.error("Error processing image: %s", e, exc_info=True)
        return {"error": f"Error generating description: {str(e)}"}


def _open_rgb(source):
    """Decode ``source`` (a file path or binary file object) into an RGB PIL image."""
    # Image.open is lazy and keeps the underlying file open. convert() returns
    # a new, fully loaded image, so the source can be closed immediately
    # instead of leaking a file handle.
    with Image.open(source) as img:
        return img.convert("RGB")


def _show_image(image, title=None):
    """Display ``image`` in a matplotlib figure with axes hidden and an optional title."""
    plt.figure(figsize=(10, 8))
    plt.imshow(image)
    plt.axis("off")
    if title:
        plt.title(title)
    plt.show()


def _generate(image, prompt, max_new_tokens):
    """Run greedy generation for one prompt about ``image`` and return only the new text."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    # The chat template inserts the image placeholder tokens. The processor
    # then expands them to match the image's patch grid and also returns the
    # image_grid_thw tensor the model needs alongside pixel_values.
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(text=[text], images=[image], return_tensors="pt").to(model.device)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    # generate() returns prompt + continuation. Slice off the prompt tokens
    # rather than string-replacing the prompt out of the decoded text, which
    # breaks whenever decoding does not reproduce the prompt verbatim.
    new_tokens = output_ids[:, inputs["input_ids"].shape[1]:]
    decoded = processor.batch_decode(
        new_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return decoded[0].strip()


def load_model(model_id="Qwen/Qwen2-VL-7B-Instruct"):
    """
    Load a Qwen2-VL checkpoint and its processor into the module-level globals.

    Args:
        model_id (str): Hugging Face Hub id or local path of a Qwen2-VL
            checkpoint. Defaults to the instruction-tuned 7B checkpoint,
            since describe_image() sends chat-formatted instructions.

    Returns:
        bool: True on success. False if loading failed; in that case the
        error is logged and the globals are left as they were.
    """
    global model, processor, tokenizer

    try:
        logger.info("Loading model %s...", model_id)
        # Imported here so that a transformers release without Qwen2-VL
        # support is reported through the normal False return instead of
        # breaking the import of this module.
        from transformers import (
            AutoProcessor,
            BitsAndBytesConfig,
            Qwen2VLForConditionalGeneration,
        )

        # Qwen2-VL is not a plain causal LM. It needs its own model class and
        # its processor (image processor + tokenizer + chat template), not
        # AutoModelForCausalLM with a CLIP image processor.
        new_processor = AutoProcessor.from_pretrained(model_id)
        new_model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            # Passing load_in_4bit straight to from_pretrained is deprecated;
            # bitsandbytes quantization is configured via BitsAndBytesConfig.
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            ),
            device_map="auto",
        )
    except Exception as e:
        logger.error("Error loading model: %s", e, exc_info=True)
        return False

    # Publish all components together so a failure part-way through loading
    # never leaves a half-initialised set of globals behind.
    processor = new_processor
    tokenizer = new_processor.tokenizer  # kept for callers using the global
    model = new_model
    logger.info("Model loaded successfully")
    return True
| |
|
def main(image_path=None, show_image=True):
    """
    Run the command-line demo: describe one image and print the results.

    Args:
        image_path (str, optional): Image to describe. Defaults to
            ``data_temp/page_2.png``.
        show_image (bool): Whether to display the image before describing it.

    Returns:
        int: Process exit status. 0 on success, 1 if describe_image()
        reported an error.
    """
    if image_path is None:
        image_path = os.path.join("data_temp", "page_2.png")

    result = describe_image(image_path=image_path, show_image=show_image)

    if "error" in result:
        print(result["error"])
        return 1

    print("\n==== Image Description Results (Qwen2-VL-7B) ====")
    print(f"\nBasic Description:\n{result['basic_description']}")
    print(f"\nDetailed Description:\n{result['detailed_description']}")
    print(f"\nTechnical Analysis:\n{result['technical_analysis']}")
    return 0


if __name__ == "__main__":
    import sys

    # An optional first CLI argument overrides the default demo image path;
    # a failed description now yields a non-zero exit status.
    sys.exit(main(sys.argv[1] if len(sys.argv) > 1 else None))