File size: 4,928 Bytes
c5ad64e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
import requests
import io
from typing import List
from PIL.Image import Image
from langchain_core.tools import tool
from pinecone import Pinecone

# Endpoint of the Swin Transformer skin-condition classifier; override with
# the SWIN_MODEL_URL environment variable.
SWIN_API_URL = os.environ.get("SWIN_MODEL_URL", "https://api-inference.huggingface.co/models/Jyo-K/skin_swin")
# Token sent as a Bearer credential to both Hugging Face endpoints below.
HF_API_KEY = os.environ.get("HF_API_KEY")
# Text-embedding endpoint used to vectorise Pinecone queries. Overridable via
# EMBEDDING_MODEL_URL, mirroring SWIN_MODEL_URL; the default is unchanged.
EMBEDDING_API_URL = os.environ.get(
    "EMBEDDING_MODEL_URL",
    "https://api-inference.huggingface.co/models/sentence-transformers/all-MiniLM-L6-v2",
)
# Pinecone credentials; both must be set before the index can be opened.
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX_NAME")

# Class names of the Swin model, looked up by the <n> of a "LABEL_<n>"
# prediction (LABEL_0 -> first entry). The "<n>. " prefix is stripped before
# the name is returned. Names are Turkish: infectious, eczema, acne,
# pigmentation, benign, malignant.
SWIN_LABELS = [
    '1. Enfeksiyonel',
    '2. Ekzama',
    '3. Akne',
    '4. Pigment',
    '5. Benign',
    '6. Malign',
]


# Pinecone client and index handle, created lazily on first use by
# get_pinecone_index() so importing this module needs no credentials.
_pinecone_client = None
_pinecone_index = None

def get_pinecone_index():
    """Return the shared Pinecone index handle, creating it on the first call.

    The client and index are cached in the module-level ``_pinecone_client`` /
    ``_pinecone_index`` globals, so later calls reuse the same objects.

    Raises:
        ValueError: If PINECONE_API_KEY or PINECONE_INDEX_NAME is unset.
    """
    global _pinecone_client, _pinecone_index

    # Fast path: already initialised.
    if _pinecone_index is not None:
        return _pinecone_index

    if not (PINECONE_API_KEY and PINECONE_INDEX_NAME):
        raise ValueError("PINECONE_API_KEY or PINECONE_INDEX_NAME not set.")

    _pinecone_client = Pinecone(api_key=PINECONE_API_KEY)
    _pinecone_index = _pinecone_client.Index(PINECONE_INDEX_NAME)
    print("--- Pinecone Index Initialized ---")
    return _pinecone_index

def get_embedding_hf(text: str, timeout: float = 60.0) -> List[float]:
    """
    Embed ``text`` with the Hugging Face Inference API.

    Args:
        text: The query string to embed.
        timeout: Seconds ``requests`` waits for the connection and for the
            response before giving up. Keep it generous: the request asks
            the API to wait for the model to load (``wait_for_model``).

    Returns:
        The embedding vector for ``text``.

    Raises:
        ValueError: If ``HF_API_KEY`` is unset, or the response is empty or
            not a list.
        requests.HTTPError: If the API answers with an error status.
        requests.Timeout: If no response arrives within ``timeout``.
    """
    if not HF_API_KEY:
        raise ValueError("HF_API_KEY not set. Cannot get embeddings.")

    # Without a timeout, requests can block indefinitely on a stalled server.
    response = requests.post(
        EMBEDDING_API_URL,
        headers={"Authorization": f"Bearer {HF_API_KEY}"},
        json={"inputs": text, "options": {"wait_for_model": True}},
        timeout=timeout,
    )
    response.raise_for_status()
    payload = response.json()

    # Depending on the model/pipeline, the body may be a flat vector (one
    # pooled sentence embedding) or a list wrapping that vector. Blindly
    # taking payload[0] would return a single float in the flat case.
    if isinstance(payload, dict):
        raise ValueError(f"Unexpected embedding API response: {payload}")
    if not payload:
        raise ValueError("Embedding API returned an empty response.")
    if isinstance(payload[0], list):
        return payload[0]
    return payload

@tool
def tool_analyze_skin_image(image: Image) -> str:
    """
    Classify a skin-condition photo with the Swin Transformer Inference API.

    The image is JPEG-encoded and POSTed to ``SWIN_API_URL``. The
    highest-scoring prediction is turned into a disease name: ``LABEL_<n>``
    labels are resolved through ``SWIN_LABELS``, and any ``"<n>. "`` prefix
    is stripped.

    Args:
        image: The PIL image to classify.

    Returns:
        The predicted disease name. Failures (missing token, unencodable
        image, HTTP/API errors, unexpected responses) are reported as a
        string starting with ``"Error"`` instead of being raised.
    """
    if not HF_API_KEY:
        return "Error: Hugging Face API token not found."

    headers = {"Authorization": f"Bearer {HF_API_KEY}"}

    try:
        # JPEG cannot store alpha or palette data; PIL raises OSError when
        # asked to save e.g. an RGBA PNG as JPEG, so normalise to RGB first.
        rgb_image = image if image.mode == "RGB" else image.convert("RGB")
        buffered = io.BytesIO()
        rgb_image.save(buffered, format="JPEG")
        img_data = buffered.getvalue()
    except (OSError, ValueError) as e:
        print(f"Image Analysis Tool Error: {e}")
        return f"Error: Could not encode image as JPEG: {e}"

    try:
        # Without a timeout, requests can block indefinitely.
        response = requests.post(
            SWIN_API_URL,
            headers=headers,
            data=img_data,
            timeout=60,
        )
        response.raise_for_status()
        api_output = response.json()

        if isinstance(api_output, dict) and 'error' in api_output:
            return f"Error from Swin API: {api_output['error']}"

        # Tolerate a batch-wrapped response: [[{label, score}, ...]].
        if isinstance(api_output, list) and api_output and isinstance(api_output[0], list):
            api_output = api_output[0]

        if not (isinstance(api_output, list) and api_output):
            return "Error: Invalid API response format from Swin model."

        top_prediction = max(api_output, key=lambda x: x['score'])
        label_name = top_prediction['label']

        if "LABEL_" in label_name:
            try:
                idx = int(label_name.split('_')[-1])
            except ValueError:
                return f"Error: Model returned unknown label {label_name}"
            # Reject negative indices too: SWIN_LABELS[-1] would silently
            # map an invalid label to the last class.
            if not 0 <= idx < len(SWIN_LABELS):
                return f"Error: Model returned unknown label {label_name}"
            disease_name_with_prefix = SWIN_LABELS[idx]
        else:
            disease_name_with_prefix = label_name

        # Drop the "<n>. " ordinal prefix used in SWIN_LABELS.
        disease_name = disease_name_with_prefix.split('. ')[-1]
        print(f"Image Analysis Tool: Predicted '{disease_name}'")
        return disease_name

    except Exception as e:
        # Tool boundary: report the failure to the agent instead of raising.
        print(f"Image Analysis Tool Error: {e}")
        return f"Error during Swin API call: {e}"

@tool
def tool_fetch_disease_info(disease_name: str) -> dict:
    """
    Look up symptoms and treatment for a disease in the Pinecone vector DB.

    The disease name is embedded with ``get_embedding_hf`` and the single
    nearest record is fetched together with its metadata. A match scoring
    below 0.5 is treated as "not found".

    Args:
        disease_name: Disease name to search for (for example, the output of
            ``tool_analyze_skin_image``).

    Returns:
        On success, a dict with keys ``disease``, ``symptoms`` (list of
        strings), ``treatment`` and ``context``. On any failure, a dict with
        a single ``error`` key describing the problem.
    """
    # Minimum similarity accepted as a real match. Its meaning depends on the
    # index's distance metric (presumably cosine) - TODO confirm.
    min_score = 0.5

    try:
        index = get_pinecone_index()
    except ValueError as e:
        return {"error": str(e)}

    try:
        print(f"Vector DB Tool: Getting embedding for '{disease_name}'")
        query_embedding = get_embedding_hf(disease_name)

        query_response = index.query(
            vector=query_embedding,
            top_k=1,
            include_metadata=True
        )

        matches = query_response.get('matches')
        if not matches or matches[0]['score'] < min_score:
            return {"error": f"No high-confidence information found for '{disease_name}' in the database."}

        # Guard against a match that carries no metadata (presumably None).
        metadata = matches[0]['metadata'] or {}

        # Symptoms may be stored as one comma-separated string or as a
        # list-of-strings metadata field; accept both, and a None value.
        raw_symptoms = metadata.get("symptoms") or ""
        if isinstance(raw_symptoms, str):
            raw_symptoms = raw_symptoms.split(',')
        symptoms_list = [str(s).strip() for s in raw_symptoms if str(s).strip()]

        treatment = metadata.get("treatment", "No treatment information found.")

        return {
            "disease": metadata.get("disease", disease_name),
            "symptoms": symptoms_list,
            "treatment": treatment,
            "context": metadata.get("text_content", ""),
        }
    except Exception as e:
        # Tool boundary: report the failure to the agent instead of raising.
        print(f"Vector DB Tool Error: {e}")
        return {"error": f"Error during Pinecone query: {e}"}