import faiss
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer, pipeline
import requests
from bs4 import BeautifulSoup
import os
import gradio as gr
import asyncio  # for asynchronous processing of the generation calls
# Step 1: Define a minimal PromptTemplate class (mimics LangChain's interface)
class PromptTemplate:
    def __init__(self, template):
        self.template = template

    def format(self, **kwargs):
        formatted_text = self.template
        for key, value in kwargs.items():
            formatted_text = formatted_text.replace("{" + key + "}", str(value))
        return formatted_text
# Step 2: Load the embedding model and tokenizer
embedding_model_name = "ls-da3m0ns/bge_large_medical"
embedding_tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
embedding_model = AutoModel.from_pretrained(embedding_model_name)
embedding_model.eval()  # Set model to evaluation mode

# Move the embedding model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embedding_model.to(device)
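
# Note: the embeddings used below are mean-pooled last hidden states. This
# assumes the Faiss index was built from embeddings pooled (and normalized)
# the same way; if it was built differently, nearest-neighbor results will
# not line up with the stored URLs.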
# Step 3: Load the Faiss index
index_file = "faiss_index.index"
if os.path.exists(index_file):
    index = faiss.read_index(index_file)
    assert isinstance(index, faiss.IndexFlat), "Expected Faiss IndexFlat type"
    assert index.d == 1024, f"Expected index dimension 1024, but got {index.d}"
else:
    raise FileNotFoundError(f"Faiss index file '{index_file}' not found.")
# Step 4: Load the crawled URLs (one per line, row-aligned with the index)
urls_file = "crawled_urls.txt"
if os.path.exists(urls_file):
    with open(urls_file, "r") as f:
        urls = [line.strip() for line in f if line.strip()]
else:
    raise FileNotFoundError(f"URLs file '{urls_file}' not found.")
# Step 5: Check if the sample embeddings file exists; if not, create it
sample_embeddings_file = "sample_embeddings.npy"
if not os.path.exists(sample_embeddings_file):
    print("Sample embeddings file not found, creating new sample embeddings...")
    # Generate sample data to fit PCA
    sample_texts = [
        "medical diagnosis",
        "healthcare treatment",
        "patient care",
        "clinical research",
        "disease prevention",
    ]
    sample_embeddings = []
    for text in sample_texts:
        inputs = embedding_tokenizer(text, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = embedding_model(**inputs)
        embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
        sample_embeddings.append(embedding)
    sample_embeddings = np.vstack(sample_embeddings)
    np.save(sample_embeddings_file, sample_embeddings)
else:
    sample_embeddings = np.load(sample_embeddings_file)
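
# Note: sample_embeddings is not used anywhere else in this script; it appears
# to be intended as fitting data for a PCA/dimensionality-reduction step
# outside this file, so this block only ensures the .npy file exists.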
# Step 6: Similarity search over the Faiss index
def search_similar(query_text, top_k=3):
    inputs = embedding_tokenizer(query_text, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = embedding_model(**inputs)
    query_embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
    query_embedding = query_embedding / np.linalg.norm(query_embedding)
    query_embedding = query_embedding.reshape(1, -1).astype(np.float32)
    _, idx = index.search(query_embedding, top_k)
    results = []
    for i in range(top_k):
        key = int(idx[0][i])
        if 0 <= key < len(urls):  # Faiss returns -1 when fewer than top_k neighbors exist
            results.append(urls[key])  # Return URLs only for simplicity
    return results
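
# Illustrative usage (hypothetical query), assuming the index and URL list
# loaded above:
#   search_similar("treatment options for type 2 diabetes", top_k=3)
#   -> up to 3 URLs ranked by embedding similarity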
# Step 7: Extract paragraph text from a URL
def extract_content(url):
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        soup = BeautifulSoup(response.content, 'html.parser')
        paragraphs = soup.find_all('p')
        # Join paragraphs with newlines so adjacent sentences don't run together
        relevant_content = "\n".join(para.get_text().strip() for para in paragraphs)
        return relevant_content.strip()  # Return relevant content as a single string
    except requests.RequestException as e:
        print(f"Error fetching content from {url}: {e}")
        return ""
# Step 8: Set up the Hugging Face text-generation pipeline for producing answers
generation_model_name = "microsoft/Phi-3-mini-4k-instruct"
# device=-1 runs on CPU; change to device=0 to use the first GPU
text_generator = pipeline("text-generation", model=generation_model_name, device=-1)
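# Note: depending on your transformers version, loading Phi-3 may require
# passing trust_remote_code=True to pipeline(); newer releases support it natively.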
# Step 9: Generate answers for the query from the retrieved page contents
async def generate_answer(query, contents):
    prompt_template = PromptTemplate("""
### Medical Assistant Context ###
As a helpful medical assistant, I'm here to assist you with your query.

### Medical Query ###
Query: {query}

### Retrieved Context ###
{content}

### Response ###
Response:""")
    batch_prompts = []
    for content in contents:
        if content:
            batch_prompts.append(prompt_template.format(query=query, content=content))
    if not batch_prompts:
        return ["No content available to generate an answer."] * len(contents)
    # Generate responses concurrently; the pipeline call blocks, so run each one
    # in an executor thread. Binding the prompt as a default argument (p=prompt)
    # avoids the late-binding pitfall where every task would see the last prompt.
    loop = asyncio.get_running_loop()
    generated_texts = await asyncio.gather(*[
        loop.run_in_executor(
            None,
            lambda p=prompt: text_generator(p, max_new_tokens=200, num_return_sequences=1, truncation=True),
        )
        for prompt in batch_prompts
    ])
    answers = []
    for generated_text in generated_texts:
        if generated_text and isinstance(generated_text, list) and len(generated_text) > 0:
            response = generated_text[0]["generated_text"]
            # The pipeline echoes the prompt, so keep only the text after the
            # final "Response:" marker
            response_start = response.rfind("Response:") + len("Response:")
            answers.append(response[response_start:].strip())
        else:
            answers.append("No AI-generated text found.")
    return answers
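
# Illustrative usage (hypothetical text), run from synchronous code:
#   answers = asyncio.run(generate_answer("What causes anemia?", [page_text]))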
# Gradio interface
def process_query(query):
    top_results = search_similar(query, top_k=3)
    if top_results:
        content = extract_content(top_results[0])
        answer = asyncio.run(generate_answer(query, [content]))[0]
        response = f"Rank 1: URL - {top_results[0]}\n"
        response += f"Generated Answer:\n{answer}\n"
        similar_urls = "\n".join(top_results[1:])  # Second and third hits, shown as similar URLs
        return response, similar_urls
    else:
        return "No results found.", "No similar URLs found."
demo = gr.Interface(
    fn=process_query,
    inputs=gr.Textbox(label="Enter your query"),
    outputs=[
        gr.Textbox(label="Generated Answer"),
        gr.Textbox(label="Similar URLs"),
    ],
)

if __name__ == "__main__":
    demo.launch()