File size: 4,799 Bytes
76bba0b
 
 
 
 
 
 
 
 
 
 
 
 
 
cc8732a
76bba0b
cc8732a
76bba0b
 
 
 
 
 
 
 
 
 
cc8732a
 
 
 
 
 
76bba0b
 
 
 
 
 
 
 
 
 
 
 
 
cc8732a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76bba0b
cc8732a
 
 
 
 
 
 
 
 
 
76bba0b
 
 
cc8732a
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
Model routing and management for fraud detection
"""
import logging
from typing import Optional, Dict, Any
from huggingface_hub import InferenceClient

logger = logging.getLogger(__name__)

class ModelRouter:
    """Routes inference requests to task-specific Hugging Face models.

    Holds one ``InferenceClient`` per task ("analysis", "coding") and exposes
    a single ``run()`` entry point that tries chat completion first and falls
    back to plain text generation. All failures are caught and logged; ``run``
    never raises to the caller.
    """

    def __init__(self, hf_token: str):
        """Create a router and eagerly initialize one client per task.

        Args:
            hf_token: Hugging Face API token. If falsy, no clients are
                created and ``run()`` returns an availability message.
        """
        self.hf_token = hf_token
        # Updated model names - using more reliable models
        self.models = {
            "analysis": "mistralai/Mistral-7B-Instruct-v0.2",  # More stable model
            "coding": "Qwen/Qwen2.5-Coder-32B-Instruct",
        }
        self.clients = {}
        self._init_clients()

    def _init_clients(self):
        """Initialize one ``InferenceClient`` per configured task.

        Failures are logged rather than raised: a task whose client cannot
        be built is simply absent from ``self.clients``.
        """
        for task, model_name in self.models.items():
            try:
                if self.hf_token:
                    # Explicit timeout so a hung endpoint surfaces as an
                    # error instead of blocking the caller indefinitely.
                    self.clients[task] = InferenceClient(
                        model=model_name,
                        token=self.hf_token,
                        timeout=60,
                    )
                    logger.info("Initialized client for %s: %s", task, model_name)
                else:
                    logger.warning("No HF token provided, %s model will not be available", task)
            except Exception as e:
                logger.error("Failed to initialize %s client: %s", task, e)

    def run(self, prompt: str, task: str = "analysis", max_tokens: int = 500) -> str:
        """Run inference with the model registered for *task*.

        Args:
            prompt: User prompt to send to the model.
            task: Key into ``self.models`` ("analysis" or "coding").
            max_tokens: Generation budget for the response.

        Returns:
            The model's reply text, or a human-readable fallback message
            when the client is unavailable or every inference attempt
            fails. This method never raises.
        """
        if task not in self.clients:
            return f"Model for task '{task}' not available. Please check HF_TOKEN."

        try:
            client = self.clients[task]

            # Prefer the chat API: both configured models are
            # instruction-tuned and behave best with a system prompt.
            try:
                messages = [
                    {"role": "system", "content": "You are a helpful fraud detection analyst AI. Be concise, clear, and practical."},
                    {"role": "user", "content": prompt},
                ]

                response = client.chat_completion(
                    messages=messages,
                    max_tokens=max_tokens,
                    temperature=0.7,
                )

                # The hub client may return an object or a plain dict
                # depending on library version; handle both shapes.
                if hasattr(response, 'choices') and len(response.choices) > 0:
                    return response.choices[0].message.content.strip()
                elif isinstance(response, dict) and 'choices' in response:
                    return response['choices'][0]['message']['content'].strip()
                else:
                    return str(response).strip()

            except Exception as chat_error:
                logger.warning("Chat completion failed, trying text generation: %s", chat_error)

                # Fallback for endpoints that don't expose the chat route.
                response = client.text_generation(
                    prompt,
                    max_new_tokens=max_tokens,
                    temperature=0.7,
                    do_sample=True,
                    return_full_text=False,
                )

                # text_generation may yield a plain string or a details
                # object with a `generated_text` attribute.
                if isinstance(response, str):
                    return response.strip()
                elif hasattr(response, 'generated_text'):
                    return response.generated_text.strip()
                else:
                    return str(response).strip()

        except Exception:
            # logger.exception records the full traceback, unlike
            # logger.error with just str(e).
            logger.exception("Error running %s model", task)

            # Degrade gracefully: callers still get a report shell to show.
            return """
🔍 FRAUD ANALYSIS REPORT (Fallback Mode)
📊 Analysis could not be completed with AI model, but fraud detection algorithms ran successfully.
⚠️ Please check the flagged transactions below for manual review.
🎯 Recommended: Review all flagged transactions and consider implementing additional monitoring.
            """.strip()

    def get_available_models(self) -> Dict[str, str]:
        """Return task -> model-name for tasks whose client initialized."""
        return {task: model for task, model in self.models.items() if task in self.clients}

    def health_check(self) -> Dict[str, bool]:
        """Probe each initialized client with a tiny generation request.

        Returns:
            Mapping of task name to liveness (True if the probe succeeded).
        """
        status = {}
        for task, client in self.clients.items():
            try:
                # Minimal prompt keeps the probe cheap; only success or
                # failure matters, not the generated text.
                client.text_generation("Hello", max_new_tokens=10)
                status[task] = True
            except Exception as e:
                logger.error("Health check failed for %s: %s", task, e)
                status[task] = False
        return status