# ai_chat_api/services/vision_service.py
import logging
import base64
import io
from typing import Optional, Dict, Any
from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler
from huggingface_hub import hf_hub_download
from PIL import Image
from config import config
# ADD THIS IMPORT
from utils.json_extractor import extract_json_from_content
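# `extract_json_from_content` is assumed to pull the first ```json ... ``` block
# out of a model reply and return it parsed (e.g. via json.loads); see
# utils/json_extractor.py for the actual contract.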
logger = logging.getLogger("vision-service")
class VisionService:
"""Service for vision-language model interactions"""
def __init__(self):
self.model: Optional[Llama] = None
self.chat_handler: Optional[Llava15ChatHandler] = None

    async def initialize(self) -> None:
        """Download and load the vision model and its CLIP projector."""
        try:
            logger.info(f"Downloading vision model: {config.VISION_MODEL_FILE}...")
            model_path = hf_hub_download(
                repo_id=config.VISION_MODEL_REPO,
                filename=config.VISION_MODEL_FILE,
                cache_dir=config.HF_HOME
            )
            logger.info(f"Downloading vision projector: {config.VISION_MMPROJ_FILE}...")
            mmproj_path = hf_hub_download(
                repo_id=config.VISION_MODEL_REPO,
                filename=config.VISION_MMPROJ_FILE,
                cache_dir=config.HF_HOME
            )
            logger.info(f"Loading vision model (Threads: {config.N_THREADS})...")
            self.chat_handler = Llava15ChatHandler(
                clip_model_path=mmproj_path,
                verbose=False
            )
            self.model = Llama(
                model_path=model_path,
                chat_handler=self.chat_handler,
                n_ctx=config.VISION_MODEL_CTX,
                n_threads=config.N_THREADS,
                n_batch=config.VISION_MODEL_BATCH,
                logits_all=True,  # needed for the LLaVA chat handler
                verbose=False
            )
            logger.info("✓ Vision model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to initialize vision model: {e}")
            raise

    def is_ready(self) -> bool:
        return self.model is not None and self.chat_handler is not None

    # UPDATED METHOD
    async def analyze_image(
        self,
        image_data: bytes,
        prompt: str,
        temperature: float = 0.6,
        max_tokens: int = 512,
        return_json: bool = False  # Added parameter
    ) -> Dict[str, Any]:
        """Analyze an image with a text prompt."""
        if not self.is_ready():
            raise RuntimeError("Vision model not initialized")
        try:
            # Convert image bytes to base64 for the data URI
            image_b64 = base64.b64encode(image_data).decode('utf-8')
            # Validate the image and derive the MIME type for the data URI
            # (rather than hardcoding image/jpeg for non-JPEG uploads)
            image = Image.open(io.BytesIO(image_data))
            logger.info(f"Processing image: {image.size} | Format: {image.format}")
            mime_type = f"image/{(image.format or 'JPEG').lower()}"
            # Modify the prompt if JSON output is requested.
            # Note: for LLaVA-style vision models it is often safer to append the
            # system instruction to the user text than to send a separate
            # system-role message.
            final_prompt = prompt
            if return_json:
                final_prompt += (
                    "\n\nYou are a strict JSON generator. "
                    "Convert the output into valid JSON format. "
                    "Output strictly in markdown code blocks like ```json ... ```. "
                    "Do not add conversational filler."
                )
            # Build the multimodal chat message (image part first, then text)
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{image_b64}"}},
                        {"type": "text", "text": final_prompt}
                    ]
                }
            ]
logger.info(f"Analyzing image with prompt: {prompt[:50]}... | JSON: {return_json}")
response = self.model.create_chat_completion(
messages=messages,
temperature=temperature,
max_tokens=max_tokens
)
content_text = response['choices'][0]['message']['content']
# Logic for return_json
if return_json:
extracted_data = extract_json_from_content(content_text)
return {
"status": "success",
"data": extracted_data,
"image_info": {
"size": list(image.size),
"format": image.format
},
"usage": response.get('usage', {})
}
# Standard return
return {
"status": "success",
"image_info": {
"size": list(image.size),
"format": image.format,
"mode": image.mode
},
"prompt": prompt,
"response": content_text,
"usage": response.get('usage', {})
}
except Exception as e:
logger.error(f"Error analyzing image: {e}")
raise

    async def cleanup(self) -> None:
        # Drop references so llama.cpp resources can be released
        if self.model:
            del self.model
            self.model = None
        if self.chat_handler:
            del self.chat_handler
            self.chat_handler = None
        logger.info("Vision model unloaded")

# Global instance
vision_service = VisionService()
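
# Minimal usage sketch, assuming valid model settings in `config` and a local
# "photo.jpg" (both hypothetical here); the service is normally driven by the API.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        await vision_service.initialize()
        with open("photo.jpg", "rb") as f:
            result = await vision_service.analyze_image(
                f.read(),
                prompt="Describe the objects in this image.",
                return_json=True,
            )
        print(result["data"])
        await vision_service.cleanup()

    asyncio.run(_demo())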