# NOTE(review): the lines "Spaces: / Running / Running" were Hugging Face
# Spaces status-page residue from a copy/paste, not code; commented out so
# the module parses.
| import logging | |
| import base64 | |
| import io | |
| from typing import Optional, Dict, Any | |
| from llama_cpp import Llama | |
| from llama_cpp.llama_chat_format import Llava15ChatHandler | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image | |
| from config import config | |
# Local helper that pulls a JSON object out of the model's text output
# (used when analyze_image is called with return_json=True).
| from utils.json_extractor import extract_json_from_content | |
# Module-level logger; named after the service rather than __name__ so log
# lines are grouped under "vision-service" regardless of import path.
logger = logging.getLogger("vision-service")
class VisionService:
    """Service for vision-language model interactions.

    Lifecycle: ``initialize()`` downloads the GGUF weights and CLIP projector
    from the Hugging Face Hub and loads them via llama.cpp; ``analyze_image()``
    runs multimodal inference; ``cleanup()`` releases the model.
    """

    def __init__(self) -> None:
        # Both attributes are populated by initialize(); is_ready() gates on them.
        self.model: Optional[Llama] = None
        self.chat_handler: Optional[Llava15ChatHandler] = None

    async def initialize(self) -> None:
        """Download and load the vision model and its CLIP projector.

        Raises:
            Exception: any failure from hf_hub_download / Llama is logged
                and re-raised so the caller can abort startup.
        """
        try:
            logger.info(f"Downloading vision model: {config.VISION_MODEL_FILE}...")
            model_path = hf_hub_download(
                repo_id=config.VISION_MODEL_REPO,
                filename=config.VISION_MODEL_FILE,
                cache_dir=config.HF_HOME,
            )
            logger.info(f"Downloading vision projector: {config.VISION_MMPROJ_FILE}...")
            mmproj_path = hf_hub_download(
                repo_id=config.VISION_MODEL_REPO,
                filename=config.VISION_MMPROJ_FILE,
                cache_dir=config.HF_HOME,
            )
            logger.info(f"Loading vision model (Threads: {config.N_THREADS})...")
            self.chat_handler = Llava15ChatHandler(
                clip_model_path=mmproj_path,
                verbose=False,
            )
            self.model = Llama(
                model_path=model_path,
                chat_handler=self.chat_handler,
                n_ctx=config.VISION_MODEL_CTX,
                n_threads=config.N_THREADS,
                n_batch=config.VISION_MODEL_BATCH,
                # Required by the LLaVA chat handler, which needs logits for
                # every position when embedding the image tokens.
                logits_all=True,
                verbose=False,
            )
            logger.info("✓ Vision model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to initialize vision model: {e}")
            raise

    def is_ready(self) -> bool:
        """Return True once both the LLM and the chat handler are loaded."""
        return self.model is not None and self.chat_handler is not None

    async def analyze_image(
        self,
        image_data: bytes,
        prompt: str,
        temperature: float = 0.6,
        max_tokens: int = 512,
        return_json: bool = False,
    ) -> Dict[str, Any]:
        """Analyze an image with a text prompt.

        Args:
            image_data: Raw image bytes (any format Pillow can decode).
            prompt: User instruction/question about the image.
            temperature: Sampling temperature for generation.
            max_tokens: Maximum tokens to generate.
            return_json: When True, append a strict-JSON instruction to the
                prompt and return the parsed object under ``"data"``.

        Returns:
            Dict with ``"status"``, ``"image_info"``, token ``"usage"``, and
            either ``"data"`` (parsed JSON) or ``"response"`` (raw text).

        Raises:
            RuntimeError: if initialize() has not completed.
            Exception: re-raised from Pillow / llama.cpp on failure.
        """
        if not self.is_ready():
            raise RuntimeError("Vision model not initialized")
        try:
            # Validate the bytes decode as an image BEFORE base64-encoding,
            # so malformed input fails fast with a Pillow error.
            image = Image.open(io.BytesIO(image_data))
            logger.info(f"Processing image: {image.size} | Format: {image.format}")

            # BUGFIX: the data URI previously hard-coded "image/jpeg" even for
            # PNG/WebP/etc. inputs. Derive the real MIME type from Pillow's
            # registry, falling back to JPEG when the format is unknown.
            mime_type = Image.MIME.get(image.format or "", "image/jpeg")
            image_b64 = base64.b64encode(image_data).decode('utf-8')

            # For LLaVA-style models it is safer to append the JSON instruction
            # to the user text than to use a separate system-role message.
            final_prompt = prompt
            if return_json:
                final_prompt += (
                    "\n\nYou are a strict JSON generator. "
                    "Convert the output into valid JSON format. "
                    "Output strictly in markdown code blocks like ```json ... ```. "
                    "Do not add conversational filler."
                )

            # llama-cpp-python multimodal message format: image first, then text.
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{image_b64}"}},
                        {"type": "text", "text": final_prompt},
                    ],
                }
            ]
            logger.info(f"Analyzing image with prompt: {prompt[:50]}... | JSON: {return_json}")
            response = self.model.create_chat_completion(
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            content_text = response['choices'][0]['message']['content']

            # Shared metadata for both return shapes (the "mode" key was
            # previously missing from the JSON branch; added for consistency).
            image_info = {
                "size": list(image.size),
                "format": image.format,
                "mode": image.mode,
            }

            if return_json:
                extracted_data = extract_json_from_content(content_text)
                return {
                    "status": "success",
                    "data": extracted_data,
                    "image_info": image_info,
                    "usage": response.get('usage', {}),
                }

            return {
                "status": "success",
                "image_info": image_info,
                "prompt": prompt,
                "response": content_text,
                "usage": response.get('usage', {}),
            }
        except Exception as e:
            logger.error(f"Error analyzing image: {e}")
            raise

    async def cleanup(self) -> None:
        """Release the model and chat handler so llama.cpp frees its memory."""
        if self.model:
            del self.model
            self.model = None
        if self.chat_handler:
            del self.chat_handler
            self.chat_handler = None
        logger.info("Vision model unloaded")
# Global singleton instance. The model itself is NOT loaded here —
# __init__ only sets attributes to None; callers must await initialize().
vision_service = VisionService()