File size: 5,659 Bytes
a2e3298
 
 
 
 
 
 
 
 
 
cde2f6e
 
a2e3298
 
 
 
 
 
 
 
 
 
 
cde2f6e
a2e3298
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cde2f6e
a2e3298
 
 
 
 
cde2f6e
 
a2e3298
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cde2f6e
 
 
 
 
 
 
 
 
 
 
 
a2e3298
 
 
 
 
 
cde2f6e
a2e3298
 
 
 
cde2f6e
a2e3298
 
 
 
 
 
cde2f6e
 
a2e3298
cde2f6e
 
 
 
 
 
 
 
 
 
 
 
 
 
a2e3298
 
 
 
 
 
 
 
cde2f6e
a2e3298
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import logging
import base64
import io
from typing import Optional, Dict, Any
from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler
from huggingface_hub import hf_hub_download
from PIL import Image

from config import config
# ADD THIS IMPORT
from utils.json_extractor import extract_json_from_content

# Module-level logger; every message from this service is emitted under the
# "vision-service" logger name.
logger = logging.getLogger("vision-service")

class VisionService:
    """Service for vision-language model interactions.

    Pairs a llama.cpp ``Llama`` model with a ``Llava15ChatHandler`` (built
    from a multimodal projector file) so a caller can send one image plus a
    text prompt and receive the model's reply.
    """

    # Appended to the user's prompt when ``return_json`` is requested. For
    # LLaVA-style models it is often safer to put this instruction in the
    # user turn than in a separate system-role message.
    _JSON_INSTRUCTION = (
        "\n\nYou are a strict JSON generator. "
        "Convert the output into valid JSON format. "
        "Output strictly in markdown code blocks like ```json ... ```. "
        "Do not add conversational filler."
    )

    def __init__(self):
        # Populated by initialize(); reset to None by cleanup().
        self.model: Optional[Llama] = None
        self.chat_handler: Optional[Llava15ChatHandler] = None

    async def initialize(self) -> None:
        """Download the vision model and its projector, then load both.

        Both files are fetched from ``config.VISION_MODEL_REPO`` into the
        ``config.HF_HOME`` cache with ``hf_hub_download``. The projector is
        loaded into a ``Llava15ChatHandler``, which is then attached to the
        ``Llama`` model.

        Raises:
            Exception: any download or load error is logged (with traceback)
                and re-raised unchanged.
        """
        # NOTE(review): hf_hub_download() and Llama(...) are synchronous calls
        # made directly inside this coroutine, so they block the event loop
        # while they run. Consider offloading them if startup latency matters.
        try:
            logger.info("Downloading vision model: %s...", config.VISION_MODEL_FILE)
            model_path = hf_hub_download(
                repo_id=config.VISION_MODEL_REPO,
                filename=config.VISION_MODEL_FILE,
                cache_dir=config.HF_HOME,
            )

            logger.info("Downloading vision projector: %s...", config.VISION_MMPROJ_FILE)
            mmproj_path = hf_hub_download(
                repo_id=config.VISION_MODEL_REPO,
                filename=config.VISION_MMPROJ_FILE,
                cache_dir=config.HF_HOME,
            )

            logger.info("Loading vision model (Threads: %s)...", config.N_THREADS)

            self.chat_handler = Llava15ChatHandler(
                clip_model_path=mmproj_path,
                verbose=False,
            )

            self.model = Llama(
                model_path=model_path,
                chat_handler=self.chat_handler,
                n_ctx=config.VISION_MODEL_CTX,
                n_threads=config.N_THREADS,
                n_batch=config.VISION_MODEL_BATCH,
                logits_all=True,
                verbose=False,
            )
            logger.info("✓ Vision model loaded successfully")

        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("Failed to initialize vision model: %s", e)
            raise

    def is_ready(self) -> bool:
        """Return True once both the model and its chat handler are loaded."""
        return self.model is not None and self.chat_handler is not None

    async def analyze_image(
        self,
        image_data: bytes,
        prompt: str,
        temperature: float = 0.6,
        max_tokens: int = 512,
        return_json: bool = False,
    ) -> Dict[str, Any]:
        """Analyze an image with a text prompt.

        Args:
            image_data: Raw encoded image bytes in any format Pillow can open.
            prompt: Text instruction sent alongside the image.
            temperature: Sampling temperature passed to the model.
            max_tokens: Maximum number of tokens to generate.
            return_json: When True, ask the model for JSON. The result of
                ``extract_json_from_content`` is then returned under ``"data"``.

        Returns:
            A dict with ``"status": "success"``, ``"image_info"`` (``size``,
            ``format``, ``mode``) and ``"usage"``. Plain mode also carries
            ``"prompt"`` and ``"response"`` (the raw model text). JSON mode
            carries ``"data"`` instead.

        Raises:
            RuntimeError: if the model has not been initialized.
            Exception: image-decoding and inference errors are logged (with
                traceback) and re-raised unchanged.
        """
        if not self.is_ready():
            raise RuntimeError("Vision model not initialized")

        try:
            # Image.open() parses the header only, so this rejects data Pillow
            # cannot identify. The metadata is copied out so the image can be
            # closed immediately by the context manager.
            with Image.open(io.BytesIO(image_data)) as image:
                image_size = image.size
                image_format = image.format
                image_mode = image.mode
            logger.info("Processing image: %s | Format: %s", image_size, image_format)

            # Label the data URI with the detected type instead of always
            # claiming JPEG. Unknown formats fall back to the previous default.
            mime_type = Image.MIME.get(image_format, "image/jpeg")
            image_b64 = base64.b64encode(image_data).decode("utf-8")

            final_prompt = prompt + self._JSON_INSTRUCTION if return_json else prompt

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{image_b64}"}},
                        {"type": "text", "text": final_prompt},
                    ],
                }
            ]

            logger.info("Analyzing image with prompt: %s... | JSON: %s", prompt[:50], return_json)

            # NOTE(review): this inference call is synchronous and blocks the
            # event loop for its duration. Offloading it to a thread would also
            # need serialization; confirm the Llama instance's thread-safety first.
            response = self.model.create_chat_completion(
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )

            # The chat-completion message content may be None; normalise it to ""
            # so both the JSON extractor and callers always receive a string.
            content_text = response["choices"][0]["message"]["content"] or ""

            image_info = {
                "size": list(image_size),
                "format": image_format,
                "mode": image_mode,
            }
            usage = response.get("usage", {})

            if return_json:
                return {
                    "status": "success",
                    "data": extract_json_from_content(content_text),
                    "image_info": image_info,
                    "usage": usage,
                }

            return {
                "status": "success",
                "image_info": image_info,
                "prompt": prompt,
                "response": content_text,
                "usage": usage,
            }

        except Exception as e:
            logger.exception("Error analyzing image: %s", e)
            raise

    async def cleanup(self) -> None:
        """Drop the loaded model and chat handler so they can be released.

        Safe to call repeatedly or before initialize(). Afterwards
        is_ready() returns False.
        """
        # The model holds a reference to the chat handler, so drop it first.
        self.model = None
        self.chat_handler = None
        logger.info("Vision model unloaded")

# Global instance shared by every importer of this module. It holds no model
# until its initialize() coroutine has been awaited.
vision_service = VisionService()