| """ |
| Custom Handler for QwenStem-7b on Hugging Face Endpoints |
| Handles both text and multimodal (text+image) inputs |
| """ |
|
|
| import torch |
| import base64 |
| import logging |
| from io import BytesIO |
| from typing import Dict, List, Any, Optional |
| from PIL import Image |
| from transformers import AutoProcessor, AutoModelForVision2Seq |
|
|
| |
# Module-level logger; basicConfig is called at import time so handler logs
# surface in the endpoint's log stream. NOTE(review): assumes this module is
# the process entry point (HF Endpoints imports it once) -- confirm.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
class EndpointHandler:
    """Inference handler for QwenStem-7b on Hugging Face Inference Endpoints.

    Accepts plain-text prompts and multimodal (text + base64-encoded image)
    requests, returning generated text in the standard HF Endpoints format.
    """

    def __init__(self, path=""):
        """
        Initialize the model handler for HF Endpoints.

        Args:
            path: Path to the model directory (provided by HF Endpoints).
                Falls back to the Hub repo "analist/QwenStem-7b" when empty.

        Raises:
            Exception: re-raised from processor/model loading so the endpoint
                reports a failed start instead of serving with no model.
        """
        logger.info("Initializing model from path: %s", path)

        # Prefer GPU when present; fp16 is only used on CUDA (see dtype below).
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            logger.info("Using GPU: %s", torch.cuda.get_device_name(0))
        else:
            self.device = torch.device("cpu")
            logger.info("Using CPU")

        # Single fallback so processor and model always come from the same repo.
        model_id = path if path else "analist/QwenStem-7b"

        try:
            logger.info("Loading processor...")
            self.processor = AutoProcessor.from_pretrained(
                model_id,
                trust_remote_code=True,
            )

            logger.info("Loading model...")
            self.model = AutoModelForVision2Seq.from_pretrained(
                model_id,
                trust_remote_code=True,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                low_cpu_mem_usage=True,
            ).to(self.device)

            self.model.eval()
            logger.info("Model loaded successfully!")

        except Exception as e:
            logger.error("Error loading model: %s", e)
            raise

        # Defaults merged with (and overridden by) per-request "parameters".
        # NOTE(review): 9192 * 10 (= 91920 tokens) looks like a typo for 8192;
        # kept as-is to preserve deployed behavior -- confirm the intended cap.
        self.default_generation_config = {
            "max_new_tokens": 9192 * 10,
            "temperature": 0.7,
            "top_p": 0.9,
            "do_sample": True,
            "repetition_penalty": 1.05,
        }

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process an incoming request for HF Endpoints.

        Args:
            data: Dictionary containing:
                - inputs: Text prompt (str) or dict with 'text' and optionally
                  'image' (base64 string, optionally a data URI).
                - parameters: Optional generation parameters (dict).

        Returns:
            List with a single response dictionary: {"generated_text": ...}
            on success, {"error": ..., "error_type": ...} on failure.
        """
        try:
            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {})

            logger.info("Processing request - Input type: %s", type(inputs))

            # Per-request parameters override the handler-wide defaults.
            gen_config = {**self.default_generation_config, **parameters}

            if isinstance(inputs, dict):
                text = inputs.get("text", "")
                image_data = inputs.get("image", None)

                if image_data:
                    logger.info("Processing multimodal input (text + image)")
                    response = self._process_multimodal(text, image_data, gen_config)
                else:
                    logger.info("Processing text-only input from dict")
                    response = self._process_text(text, gen_config)

            elif isinstance(inputs, str):
                logger.info("Processing text-only input")
                response = self._process_text(inputs, gen_config)

            else:
                raise ValueError(f"Unsupported input type: {type(inputs)}")

            return [{"generated_text": response}]

        except Exception as e:
            # Surface the failure to the client instead of an opaque 500.
            logger.error("Error during inference: %s", e)
            return [{"error": str(e), "error_type": type(e).__name__}]

    def _generate(self, model_inputs: Dict[str, Any], config: Dict[str, Any]):
        """Run model.generate with the merged config; return output token ids.

        Shared by the text and multimodal paths so the generation parameters
        live in exactly one place (they were previously duplicated verbatim
        in both methods).
        """
        # pad falls back to EOS, matching the original configuration.
        eos_id = self.processor.tokenizer.eos_token_id
        with torch.no_grad():
            return self.model.generate(
                **model_inputs,
                max_new_tokens=config.get("max_new_tokens", 9192 * 10),
                temperature=config.get("temperature", 0.7),
                top_p=config.get("top_p", 0.9),
                do_sample=config.get("do_sample", True),
                repetition_penalty=config.get("repetition_penalty", 1.05),
                pad_token_id=eos_id,
                eos_token_id=eos_id,
            )

    def _process_text(self, text: str, config: dict) -> str:
        """
        Generate a completion for a text-only prompt.

        Raises:
            ValueError: if the prompt is empty.
        """
        if not text:
            raise ValueError("Empty text input")

        messages = [
            {"role": "user", "content": [
                {"type": "text", "text": text}
            ]}
        ]

        # Tokenized chat prompt as a tensor of input ids on the model device.
        input_ids = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(self.device)

        outputs = self._generate({"input_ids": input_ids}, config)

        full_response = self.processor.decode(outputs[0], skip_special_tokens=True)

        # The decoded text still contains the chat template; keep the reply only.
        if "assistant" in full_response:
            response = full_response.split("assistant")[-1].strip()
        else:
            # Fallback: strip the decoded prompt prefix from the output.
            prompt_text = self.processor.decode(input_ids[0], skip_special_tokens=True)
            response = full_response[len(prompt_text):].strip()

        return response

    def _process_multimodal(self, text: str, image_b64: str, config: dict) -> str:
        """
        Generate a completion for a text + image prompt.

        Args:
            text: User prompt; a default French prompt is used when empty.
            image_b64: Base64-encoded image, optionally a "data:image/..." URI.

        Raises:
            ValueError: if the image cannot be decoded.
        """
        try:
            if image_b64.startswith('data:image'):
                # Strip the "data:image/...;base64," prefix of a data URI.
                image_b64 = image_b64.split(',')[1]

            image_bytes = base64.b64decode(image_b64)
            image = Image.open(BytesIO(image_bytes)).convert("RGB")
            logger.info("Image loaded: %s", image.size)

        except Exception as e:
            logger.error("Image decode error: %s", e)
            # Chain the cause so the original decode error is not lost.
            raise ValueError(f"Failed to decode image: {str(e)}") from e

        messages = [
            {"role": "user", "content": [
                {"type": "text", "text": text if text else "Analyse cette image."},
                {"type": "image"}
            ]}
        ]

        prompt = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )

        inputs = self.processor(
            text=prompt,
            images=[image],
            return_tensors="pt",
        )
        # Move every tensor to the model device; leave non-tensor entries as-is.
        inputs = {k: v.to(self.device) if hasattr(v, 'to') else v
                  for k, v in inputs.items()}

        outputs = self._generate(inputs, config)

        full_response = self.processor.decode(outputs[0], skip_special_tokens=True)

        if "assistant" in full_response:
            response = full_response.split("assistant")[-1].strip()
        elif text and text in full_response:
            # BUGFIX: the truthiness guard on `text` is required -- "" is a
            # substring of everything and str.split("") raises
            # ValueError("empty separator"), so image-only requests hitting
            # this fallback previously failed.
            response = full_response.split(text)[-1].strip()
        else:
            response = full_response

        return response

    def health(self) -> Dict[str, Any]:
        """
        Health check endpoint for monitoring.

        Returns:
            Dict with overall status, model load state, system info, GPU
            memory stats (when running on CUDA), and the result of a
            1-token generation smoke test when the model is loaded.
        """
        health_status = {
            "status": "healthy",
            "model": {
                "name": "QwenStem-7b",
                "type": "Vision-Language Model",
                "loaded": hasattr(self, 'model') and self.model is not None,
                "device": str(self.device) if hasattr(self, 'device') else "unknown"
            },
            "system": {
                "torch_version": torch.__version__,
                "cuda_available": torch.cuda.is_available(),
                "gpu_count": torch.cuda.device_count() if torch.cuda.is_available() else 0
            }
        }

        if torch.cuda.is_available() and hasattr(self, 'device') and self.device.type == 'cuda':
            try:
                gpu_props = torch.cuda.get_device_properties(0)
                health_status["gpu"] = {
                    "name": gpu_props.name,
                    "memory_total_gb": round(gpu_props.total_memory / (1024**3), 2),
                    "memory_allocated_gb": round(torch.cuda.memory_allocated() / (1024**3), 2),
                    "memory_reserved_gb": round(torch.cuda.memory_reserved() / (1024**3), 2),
                    "utilization_percent": round(torch.cuda.memory_allocated() / gpu_props.total_memory * 100, 2)
                }
            except Exception as e:
                logger.warning("Could not get GPU stats: %s", e)
                health_status["gpu"] = {"error": str(e)}

        if hasattr(self, 'model') and self.model is not None:
            try:
                # 1-token greedy generation as a cheap responsiveness probe.
                with torch.no_grad():
                    test_input = self.processor.apply_chat_template(
                        [{"role": "user", "content": [{"type": "text", "text": "test"}]}],
                        tokenize=True,
                        add_generation_prompt=True,
                        return_tensors="pt",
                    ).to(self.device)

                    _ = self.model.generate(
                        test_input,
                        max_new_tokens=1,
                        do_sample=False,
                    )
                health_status["model"]["responsive"] = True
            except Exception as e:
                logger.error("Model test failed: %s", e)
                health_status["model"]["responsive"] = False
                health_status["model"]["error"] = str(e)
                health_status["status"] = "degraded"

        return health_status