Spaces:
Sleeping
Sleeping
| import os | |
| import requests | |
| import io | |
| from typing import List | |
| from PIL.Image import Image | |
| from langchain_core.tools import tool | |
| from pinecone import Pinecone | |
# --- Remote endpoints and credentials (read once, at import time) ---

# Swin image-classification endpoint; can be overridden via SWIN_MODEL_URL.
SWIN_API_URL = os.getenv(
    "SWIN_MODEL_URL",
    "https://api-inference.huggingface.co/models/Jyo-K/skin_swin",
)

# Token sent as a Bearer credential to the Hugging Face endpoints.
# None when the environment variable is not set.
HF_API_KEY = os.getenv("HF_API_KEY")

# Endpoint used to turn a text query into an embedding vector.
EMBEDDING_API_URL = (
    "https://api-inference.huggingface.co/models/"
    "sentence-transformers/all-MiniLM-L6-v2"
)

# Pinecone credentials; both must be set before the index can be opened.
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME")

# Class names looked up by the numeric suffix of a generic "LABEL_<n>"
# prediction. The leading "<k>. " prefix is stripped before the name is
# returned to the caller.
SWIN_LABELS = [
    "1. Enfeksiyonel",
    "2. Ekzama",
    "3. Akne",
    "4. Pigment",
    "5. Benign",
    "6. Malign",
]

# Lazily created Pinecone handles; filled in by get_pinecone_index().
_pinecone_client = None
_pinecone_index = None
def get_pinecone_index():
    """Return the shared Pinecone index handle, creating it on first use.

    The client and the index are cached in module-level globals, so only
    the first successful call constructs them; later calls return the
    cached index.

    Raises:
        ValueError: If PINECONE_API_KEY or PINECONE_INDEX_NAME is unset
            or empty.
    """
    global _pinecone_client, _pinecone_index

    # Fast path: an earlier call already built the handle.
    if _pinecone_index is not None:
        return _pinecone_index

    config_complete = bool(PINECONE_API_KEY) and bool(PINECONE_INDEX_NAME)
    if not config_complete:
        raise ValueError("PINECONE_API_KEY or PINECONE_INDEX_NAME not set.")

    _pinecone_client = Pinecone(api_key=PINECONE_API_KEY)
    _pinecone_index = _pinecone_client.Index(PINECONE_INDEX_NAME)
    print("--- Pinecone Index Initialized ---")
    return _pinecone_index
def get_embedding_hf(text: str) -> List[float]:
    """Return the embedding vector for *text* from the HF Inference API.

    Args:
        text: Query text to embed.

    Returns:
        The embedding as a flat list of floats.

    Raises:
        ValueError: If HF_API_KEY is not set, or if the API answers with
            anything other than a non-empty JSON list.
        requests.RequestException: On a non-2xx status, a connection
            failure, or a timeout.
    """
    if not HF_API_KEY:
        raise ValueError("HF_API_KEY not set. Cannot get embeddings.")

    response = requests.post(
        EMBEDDING_API_URL,
        headers={"Authorization": f"Bearer {HF_API_KEY}"},
        json={"inputs": text, "options": {"wait_for_model": True}},
        # requests never times out by default, so a stalled connection would
        # hang the caller forever. The read limit is generous because
        # "wait_for_model" asks the API to hold the request while loading.
        timeout=(10, 120),
    )
    response.raise_for_status()

    payload = response.json()
    if not isinstance(payload, list) or not payload:
        # e.g. an {"error": ...} body delivered with a 2xx status.
        raise ValueError(f"Unexpected embedding response: {payload!r}")

    first = payload[0]
    # NOTE(review): for a single string input the endpoint presumably returns
    # either a batch ([[...]]) or one flat vector ([...]) depending on the
    # pipeline serving the model - confirm against the deployed endpoint.
    # Indexing a flat vector with [0] would return a lone float.
    if isinstance(first, (int, float)):
        return payload
    return first
def tool_analyze_skin_image(image: Image) -> str:
    """Classify a skin-condition image with the Swin Inference API.

    The image is JPEG-encoded in memory, posted to SWIN_API_URL, and the
    label with the highest score is returned. Generic "LABEL_<n>" labels
    are mapped through SWIN_LABELS, and a leading "<k>. " prefix is
    stripped from the final name.

    Args:
        image: PIL image to classify.

    Returns:
        The predicted disease name, or a string starting with "Error"
        describing what went wrong. This function does not raise for
        encoding, network, or response-format failures.
    """
    if not HF_API_KEY:
        return "Error: Hugging Face API token not found."

    headers = {"Authorization": f"Bearer {HF_API_KEY}"}

    # Modes the JPEG encoder accepts as-is; anything else (RGBA, P, LA, ...)
    # makes image.save raise, so those are converted to RGB first.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    try:
        if image.mode not in jpeg_modes:
            image = image.convert("RGB")
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        img_data = buffered.getvalue()
    except Exception as e:
        # Keep the "return an error string" contract for bad images too.
        print(f"Image Analysis Tool Error: {e}")
        return f"Error encoding image for Swin API: {e}"

    try:
        response = requests.post(
            SWIN_API_URL,
            headers=headers,
            data=img_data,
            # requests has no default timeout; bound the wait so a stalled
            # endpoint cannot hang the caller indefinitely.
            timeout=(10, 120),
        )
        response.raise_for_status()
        api_output = response.json()

        if isinstance(api_output, dict) and 'error' in api_output:
            return f"Error from Swin API: {api_output['error']}"

        if not (isinstance(api_output, list) and api_output):
            return "Error: Invalid API response format from Swin model."

        top_prediction = max(api_output, key=lambda x: x['score'])
        label_name = top_prediction['label']

        if "LABEL_" in label_name:
            # Generic label: its numeric suffix indexes SWIN_LABELS.
            try:
                idx = int(label_name.split('_')[-1])
                disease_name_with_prefix = SWIN_LABELS[idx]
            except (IndexError, ValueError):
                return f"Error: Model returned unknown label {label_name}"
        else:
            disease_name_with_prefix = label_name

        # Drop the "<k>. " ordering prefix, if there is one.
        disease_name = disease_name_with_prefix.split('. ')[-1]
        print(f"Image Analysis Tool: Predicted '{disease_name}'")
        return disease_name
    except Exception as e:
        print(f"Image Analysis Tool Error: {e}")
        return f"Error during Swin API call: {e}"
def tool_fetch_disease_info(disease_name: str) -> dict:
    """Look up symptoms and treatment for a disease in the Pinecone index.

    The disease name is embedded with get_embedding_hf and the single
    nearest vector is fetched. Matches scoring below 0.5 are treated as
    "not found".

    Args:
        disease_name: Name of the disease to search for.

    Returns:
        On success, a dict with keys "disease", "symptoms" (list of
        strings), "treatment" and "context". On any failure, a dict with
        a single "error" key describing the problem.
    """
    try:
        index = get_pinecone_index()
    except ValueError as e:
        # Missing configuration: pass the message through unchanged.
        return {"error": str(e)}
    except Exception as e:
        # Building the client or opening the index can fail for reasons
        # other than missing configuration; report those as an error dict
        # like every other failure path instead of raising.
        print(f"Vector DB Tool Error: {e}")
        return {"error": f"Error initializing Pinecone index: {e}"}

    try:
        print(f"Vector DB Tool: Getting embedding for '{disease_name}'")
        query_embedding = get_embedding_hf(disease_name)

        query_response = index.query(
            vector=query_embedding,
            top_k=1,
            include_metadata=True,
        )

        matches = query_response.get('matches')
        if not matches or matches[0]['score'] < 0.5:
            return {"error": f"No high-confidence information found for '{disease_name}' in the database."}

        metadata = matches[0]['metadata']

        # Symptoms are normally a comma-separated string.
        # NOTE(review): a list-valued field is accepted as well, on the
        # assumption that the index may store symptoms that way - confirm
        # against the ingestion code.
        raw_symptoms = metadata.get("symptoms", "")
        if isinstance(raw_symptoms, str):
            raw_symptoms = raw_symptoms.split(',')
        symptoms_list = [
            str(item).strip() for item in raw_symptoms if str(item).strip()
        ]

        treatment = metadata.get("treatment", "No treatment information found.")

        return {
            "disease": metadata.get("disease", disease_name),
            "symptoms": symptoms_list,
            "treatment": treatment,
            "context": metadata.get("text_content", ""),
        }
    except Exception as e:
        print(f"Vector DB Tool Error: {e}")
        return {"error": f"Error during Pinecone query: {e}"}