# NOTE(review): the three lines below are Hugging Face web-UI residue
# (author avatar caption, commit message, commit hash) that was pasted into
# the source file; kept here as comments so the file remains valid Python.
# punitub01's picture
# Update app.py
# 2b7fd96 verified
# import gradio as gr
# import torch
# import pandas as pd
# import faiss
# from peft import PeftModel, PeftConfig
# from transformers import AutoTokenizer, AutoModelForCausalLM
# from sentence_transformers import SentenceTransformer
# import os
# def load_components():
# tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf",token=os.getenv("HF_TOKEN"))
# tokenizer.pad_token = tokenizer.eos_token
# config = PeftConfig.from_pretrained("punitub01/llama2-7b-qlora-finetuned")
# base_model = AutoModelForCausalLM.from_pretrained(
# "meta-llama/Llama-2-7b-chat-hf",
# device_map="cpu",
# torch_dtype=torch.float16, # Changed to float32 for CPU compatibility
# token=os.getenv("HF_TOKEN")
# )
# model = PeftModel.from_pretrained(base_model, "punitub01/llama2-7b-qlora-finetuned")
# encoder = SentenceTransformer('all-MiniLM-L6-v2')
# index = faiss.read_index("diabetes_abstracts.index")
# metadata = pd.read_csv("diabetes_metadata.csv")
# return tokenizer, model, encoder, index, metadata
# # tokenizer, model, encoder, index, metadata = load_components()
# # Load other components (unchanged)
# # encoder = SentenceTransformer('all-MiniLM-L6-v2')
# # index = faiss.read_index("diabetes_abstracts.index")
# # metadata = pd.read_csv("diabetes_metadata.csv")
# # return tokenizer, model, encoder, index, metadata
# tokenizer, model, encoder, index, metadata = load_components()
# chat_history = []
# def summarize_with_llama(text):
# prompt = f"""Summarize this medical information in 1-2 lines:
# {text}
# Concise summary:"""
# inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# with torch.no_grad():
# outputs = model.generate(**inputs, max_new_tokens=100, temperature=0.1)
# return tokenizer.decode(outputs[0], skip_special_tokens=True).split("Concise summary:")[-1].strip()
# def respond(message, history):
# # Semantic search
# query_embed = encoder.encode([message])
# distances, indices = index.search(query_embed, k=3)
# # Build context
# references = [
# f"{metadata.iloc[idx]['title']} (Score: {dist:.2f})"
# for idx, dist in zip(indices[0], distances[0]) if dist >= 0.3
# ]
# context_summary = summarize_with_llama("\n".join(references)) if references else "No clinical references"
# # Generate response
# prompt = f"""Clinical Context: {context_summary}
# Chat History: {history[-2:] if history else 'None'}
# Question: {message}
# Answer:"""
# inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# with torch.no_grad():
# outputs = model.generate(**inputs, max_new_tokens=200, temperature=0.7)
# return tokenizer.decode(outputs[0], skip_special_tokens=True).split("Answer:")[-1].strip()
# # Gradio interface
# gr.ChatInterface(
# respond,
# title="Diabetes Assistant",
# description="Ask questions about diabetes management",
# examples=["What are hypoglycemia symptoms?", "¿Cómo manejar la diabetes tipo 2?"]
# ).launch()
# Stdlib and third-party imports for the RAG-based diabetes chatbot Space.
import os
import gradio as gr
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer
from llama_cpp import Llama
import logging
# Set up logging
# Module-level logger so every component reports through one channel.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
from huggingface_hub import hf_hub_download
# 1. Download the GGUF model from Hugging Face
# Runs at import time so the model file is on local disk before Llama() opens
# it; HF_TOKEN must be set as a Space secret for the (gated/private) repo.
model_path = hf_hub_download(
    repo_id="punitub01/llama2-7b-finetuned-gguf-chatbot",
    filename="model.gguf",  # or your specific GGUF filename
    local_dir="models",
    token=os.getenv("HF_TOKEN")
)
def load_components():
    """Load every runtime component of the chatbot in one place.

    Returns:
        tuple: ``(llm, encoder, index, metadata)`` — the llama.cpp model, the
        sentence-transformer encoder, the FAISS abstract index, and the
        abstract metadata DataFrame.

    Raises:
        Exception: re-raised after logging when any component fails to load.
    """
    try:
        logger.info("Checking environment...")
        logger.info(f"Directory contents: {os.listdir('.')}")
        logger.info("Loading GGUF model...")
        # n_gpu_layers=0 keeps every layer on the CPU (no GPU on this tier);
        # a small context window and two threads fit the free CPU hardware.
        language_model = Llama(
            model_path=model_path,
            n_ctx=512,
            n_threads=2,
            n_gpu_layers=0,
        )
        logger.info("Loading sentence transformer and FAISS index...")
        embedder = SentenceTransformer('all-MiniLM-L6-v2')
        abstract_index = faiss.read_index("diabetes_abstracts.index")
        abstract_meta = pd.read_csv("diabetes_metadata.csv")
        return language_model, embedder, abstract_index, abstract_meta
    except Exception as e:
        logger.error(f"Failed to load components: {str(e)}")
        raise
def summarize_with_llama(text, llm):
    """Condense *text* into a 1-2 line summary using the llama.cpp model.

    Args:
        text: Medical reference text to summarize.
        llm: A callable ``llama_cpp.Llama`` instance (or compatible callable).

    Returns:
        str: The generated summary, or a fixed error string on failure.
    """
    try:
        # Same prompt as before, assembled by concatenation instead of a
        # triple-quoted literal; the resulting string is identical.
        instruction = (
            "Summarize this medical information in 1-2 lines:\n"
            f"{text}\n"
            "Concise summary:"
        )
        # Low temperature keeps the summary deterministic; stop at the first
        # newline so the model cannot ramble past one short summary line.
        completion = llm(
            instruction,
            max_tokens=100,
            temperature=0.1,
            stop=["\n"],
            echo=False,
        )
        return completion["choices"][0]["text"].strip()
    except Exception as e:
        logger.error(f"Error in summarize_with_llama: {str(e)}")
        return "Error summarizing context"
def respond(message: str, history: list[dict]) -> str:
    """Answer a user question with RAG over the diabetes abstract index.

    Args:
        message: The user's current question.
        history: Prior turns in Gradio "messages" format:
            ``[{"role": "user", "content": ...}, {"role": "assistant", ...}]``.

    Returns:
        str: The model's answer, or an ``"Error: ..."`` string on failure.
    """
    try:
        logger.info(f"Received message: {message}")
        logger.info(f"History: {history}")
        # BUGFIX: copy the history instead of aliasing it. The previous code
        # did `messages = history` and then appended, mutating Gradio's own
        # state list as a side effect on every call.
        messages = list(history) if history else []
        messages.append({"role": "user", "content": message})
        # Semantic search: embed the query and fetch the top-3 abstracts.
        query_embed = encoder.encode([message])
        distances, indices = index.search(query_embed, k=3)
        # Build context. NOTE(review): the `dist >= 0.3` filter assumes a
        # similarity-style score (higher = better); for an L2 index smaller
        # distances are the closer matches — confirm the index metric.
        references = [
            f"{metadata.iloc[idx]['title']} (Score: {dist:.2f})"
            for idx, dist in zip(indices[0], distances[0]) if dist >= 0.3
        ]
        context_summary = summarize_with_llama("\n".join(references), llm) if references else "No clinical references"
        # Format the last two turns (this includes the just-appended message).
        chat_history = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages[-2:]]) if len(messages) >= 2 else "None"
        # Generate the answer; stop at the first newline keeps replies short.
        prompt = f"""Clinical Context: {context_summary}
Chat History: {chat_history}
Question: {message}
Answer:"""
        output = llm(
            prompt,
            max_tokens=200,
            temperature=0.7,
            stop=["\n"],
            echo=False
        )
        response = output['choices'][0]['text'].strip()
        return response
    except Exception as e:
        logger.error(f"Error in respond: {str(e)}")
        return f"Error: {str(e)}"
# Load all heavyweight components once at import time; fail fast and loudly
# so the Space shows a clear startup error instead of a half-initialized app.
try:
    llm, encoder, index, metadata = load_components()
except Exception as e:
    logger.error(f"Initialization failed: {str(e)}")
    raise
# Gradio interface
# Build and launch the chat UI. server_name="0.0.0.0" binds all interfaces
# and port 7860 is the Hugging Face Spaces default.
try:
    logger.info("Starting Gradio interface...")
    interface = gr.ChatInterface(
        fn=respond,
        type="messages",  # Modern messages format
        title="Diabetes Assistant",
        description="Ask questions about diabetes management"
    )
    interface.launch(server_name="0.0.0.0", server_port=7860)
except Exception as e:
    logger.error(f"Gradio launch failed: {str(e)}")
    raise