# --- Legacy implementation (transformers + PEFT), kept commented out for reference ---
# import gradio as gr
# import torch
# import pandas as pd
# import faiss
# from peft import PeftModel, PeftConfig
# from transformers import AutoTokenizer, AutoModelForCausalLM
# from sentence_transformers import SentenceTransformer
# import os

# def load_components():
#     tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf",token=os.getenv("HF_TOKEN"))
#     tokenizer.pad_token = tokenizer.eos_token

#     config = PeftConfig.from_pretrained("punitub01/llama2-7b-qlora-finetuned")
#     base_model = AutoModelForCausalLM.from_pretrained(
#         "meta-llama/Llama-2-7b-chat-hf",
#         device_map="cpu",
#         torch_dtype=torch.float16,  # note: float32 is generally safer for CPU-only inference
#         token=os.getenv("HF_TOKEN")
#     )
#     model = PeftModel.from_pretrained(base_model, "punitub01/llama2-7b-qlora-finetuned")

#     encoder = SentenceTransformer('all-MiniLM-L6-v2')
#     index = faiss.read_index("diabetes_abstracts.index")
#     metadata = pd.read_csv("diabetes_metadata.csv")
#     return tokenizer, model, encoder, index, metadata

# tokenizer, model, encoder, index, metadata = load_components()
# chat_history = []

# def summarize_with_llama(text):
#     prompt = f"""Summarize this medical information in 1-2 lines:
#     {text}
#     Concise summary:"""
#     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
#     with torch.no_grad():
#         outputs = model.generate(**inputs, max_new_tokens=100, temperature=0.1)
#     return tokenizer.decode(outputs[0], skip_special_tokens=True).split("Concise summary:")[-1].strip()

# def respond(message, history):
#     # Semantic search
#     query_embed = encoder.encode([message])
#     distances, indices = index.search(query_embed, k=3)
    
#     # Build context
#     references = [
#         f"{metadata.iloc[idx]['title']} (Score: {dist:.2f})" 
#         for idx, dist in zip(indices[0], distances[0]) if dist >= 0.3
#     ]
#     context_summary = summarize_with_llama("\n".join(references)) if references else "No clinical references"
    
#     # Generate response
#     prompt = f"""Clinical Context: {context_summary}
#     Chat History: {history[-2:] if history else 'None'}
#     Question: {message}
#     Answer:"""
    
#     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
#     with torch.no_grad():
#         outputs = model.generate(**inputs, max_new_tokens=200, temperature=0.7)
    
#     return tokenizer.decode(outputs[0], skip_special_tokens=True).split("Answer:")[-1].strip()

# # Gradio interface
# gr.ChatInterface(
#     respond,
#     title="Diabetes Assistant",
#     description="Ask questions about diabetes management",
#     examples=["What are hypoglycemia symptoms?", "¿Cómo manejar la diabetes tipo 2?"]
# ).launch()
import os
import logging

import gradio as gr
import pandas as pd
import faiss
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from sentence_transformers import SentenceTransformer

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 1. Download the GGUF model from Hugging Face
model_path = hf_hub_download(
    repo_id="punitub01/llama2-7b-finetuned-gguf-chatbot",
    filename="model.gguf",  # or your specific GGUF filename
    local_dir="models",
    token=os.getenv("HF_TOKEN")
)
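
# Sanity check (sketch): hf_hub_download returns the cached local path, but a
# truncated download would make llama_cpp fail with an opaque error later, so
# fail fast here instead.
if not os.path.isfile(model_path) or os.path.getsize(model_path) == 0:
    raise RuntimeError(f"GGUF download appears incomplete: {model_path}")
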
def load_components():
    try:
        logger.info("Checking environment...")
        logger.info(f"Directory contents: {os.listdir('.')}")
        
        logger.info("Loading GGUF model...")
        # Initialize Llama model with GGUF file


        # 2. Initialize the Llama instance
        llm = Llama(
            model_path=model_path,  # Use the downloaded path
            n_ctx=512,            # Context window size
            n_threads=2,           # CPU threads
            n_gpu_layers=0        # Use all GPU layers if available (typical value for 7B models)
        )

        logger.info("Loading sentence transformer and FAISS index...")
        encoder = SentenceTransformer('all-MiniLM-L6-v2')
        index = faiss.read_index("diabetes_abstracts.index")
        metadata = pd.read_csv("diabetes_metadata.csv")

        return llm, encoder, index, metadata
    except Exception as e:
        logger.error(f"Failed to load components: {str(e)}")
        raise
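
# --- Offline index build (sketch) ---------------------------------------
# load_components() assumes "diabetes_abstracts.index" and
# "diabetes_metadata.csv" already exist. Below is a minimal sketch of how
# they could be produced, assuming a source CSV with "title" and "abstract"
# columns (the filename "abstracts.csv" is hypothetical). Defined for
# reference only; never called at runtime.
def build_index(source_csv: str = "abstracts.csv") -> None:
    df = pd.read_csv(source_csv)
    encoder = SentenceTransformer('all-MiniLM-L6-v2')
    # Normalized embeddings + inner-product index = cosine similarity, which
    # matches the "score >= 0.3" relevance filter used in respond() below.
    embeddings = encoder.encode(df['abstract'].tolist(), normalize_embeddings=True)
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    faiss.write_index(index, "diabetes_abstracts.index")
    df[['title']].to_csv("diabetes_metadata.csv", index=False)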

def summarize_with_llama(text, llm):
    try:
        prompt = f"""Summarize this medical information in 1-2 lines:
        {text}
        Concise summary:"""
        
        output = llm(
            prompt,
            max_tokens=100,
            temperature=0.1,
            stop=["\n"],
            echo=False
        )
        
        return output['choices'][0]['text'].strip()
    except Exception as e:
        logger.error(f"Error in summarize_with_llama: {str(e)}")
        return "Error summarizing context"

def respond(message: str, history: list[dict]) -> str:
    try:
        logger.info(f"Received message: {message}")
        logger.info(f"History: {history}")
        
        # History arrives in Gradio's messages format:
        # [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
        # Copy it so the list Gradio passes in is never mutated.
        messages = list(history) if history else []
        messages.append({"role": "user", "content": message})

        # Semantic search
        query_embed = encoder.encode([message])
        distances, indices = index.search(query_embed, k=3)
        
        # Build context: keep hits whose score clears a relevance threshold.
        # This treats FAISS scores as similarities (higher = more relevant),
        # which holds for an inner-product index over normalized embeddings;
        # an L2 index would need the comparison inverted.
        references = [
            f"{metadata.iloc[idx]['title']} (Score: {dist:.2f})" 
            for idx, dist in zip(indices[0], distances[0]) if dist >= 0.3
        ]
        context_summary = summarize_with_llama("\n".join(references), llm) if references else "No clinical references"
        
        # Format the most recent exchange, excluding the message just appended
        # (it already appears as "Question:" in the prompt below)
        prior = messages[:-1]
        chat_history = "\n".join(f"{m['role']}: {m['content']}" for m in prior[-2:]) if prior else "None"
        
        # Generate response
        prompt = f"""Clinical Context: {context_summary}
        Chat History: {chat_history}
        Question: {message}
        Answer:"""
        
        output = llm(
            prompt,
            max_tokens=200,
            temperature=0.7,
            stop=["\n"],
            echo=False
        )
        
        response = output['choices'][0]['text'].strip()
        return response
    except Exception as e:
        logger.error(f"Error in respond: {str(e)}")
        return f"Error: {str(e)}"

try:
    llm, encoder, index, metadata = load_components()
except Exception as e:
    logger.error(f"Initialization failed: {str(e)}")
    raise

# Gradio interface
try:
    logger.info("Starting Gradio interface...")
    interface = gr.ChatInterface(
        fn=respond,
        type="messages",  # Modern messages format
        title="Diabetes Assistant",
        description="Ask questions about diabetes management"
    )
    interface.launch(server_name="0.0.0.0", server_port=7860)  # bind all interfaces on the standard Spaces/Gradio port
except Exception as e:
    logger.error(f"Gradio launch failed: {str(e)}")
    raise
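
# To run locally (sketch): set HF_TOKEN so the model download works, then
# start the app and open http://localhost:7860
#   HF_TOKEN=<your token> python app.py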