Spaces:
Paused
Paused
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import torch | |
# Run on the GPU when one is available. Inputs are moved to `device` before
# generation, so the model must live on the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): weights come from a LoRA repo while the tokenizer comes from the
# base OpenOrca-Platypus2-13B repo. Presumably the adapter was trained on that
# base; confirm the vocabularies match.
model = AutoModelForCausalLM.from_pretrained(
    "E-Hospital/open-orca-platypus-2-lora-latest",
    trust_remote_code=True,
)
# from_pretrained loads onto the CPU by default. Without this move, generate()
# raises a device-mismatch error on CUDA hosts because inputs are sent to `device`.
# NOTE(review): a 13B model in full precision needs a lot of GPU memory; consider
# torch_dtype/device_map if this OOMs.
model.to(device)

tokenizer = AutoTokenizer.from_pretrained(
    "Open-Orca/OpenOrca-Platypus2-13B", trust_remote_code=True
)
def ask_bot(question, max_new_tokens=200):
    """Generate a sampled answer to *question* with the module-level model.

    Only the newly generated tokens are decoded, so the prompt is never echoed
    back. If the continuation contains the ``->:`` answer marker, only the text
    after its last occurrence is kept.

    Args:
        question: Full prompt text to feed to the model.
        max_new_tokens: Maximum number of tokens to generate *in addition to*
            the prompt. (``max_length`` would count the prompt too, so a long
            prompt could leave no room for an answer.)

    Returns:
        The generated answer with surrounding whitespace stripped.
    """
    # tokenizer(...) also returns the attention mask, which generate() expects.
    inputs = tokenizer(question, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            do_sample=True,
            top_k=50,
        )
    # The output sequence starts with the prompt tokens; skip them.
    prompt_len = inputs["input_ids"].shape[-1]
    generated_text = tokenizer.decode(output[0, prompt_len:], skip_special_tokens=True)
    return generated_text.split("->:")[-1].strip()
| import mysql.connector | |
| import re | |
| from datetime import datetime | |
| from typing import Any, List, Mapping, Optional | |
| from langchain.memory import ConversationBufferMemory | |
| from typing import Any, List, Mapping, Optional | |
| from langchain.callbacks.manager import CallbackManagerForLLMRun | |
| from langchain.llms.base import LLM | |
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from langchain.memory import ConversationSummaryBufferMemory | |
| from langchain.memory import ConversationSummaryMemory | |
class CustomLLM(LLM):
    """LangChain ``LLM`` adapter over the module-level Hugging Face model.

    Prompts are tokenized with the module-level ``tokenizer``, sampled from
    ``model`` on ``device``, and only the newly generated text is returned.
    """

    # Opaque value reported by _identifying_params; generation does not use it.
    n: int
    # Maximum number of tokens generated per call, not counting the prompt.
    max_new_tokens: int = 100

    @property
    def _llm_type(self) -> str:
        """Type tag LangChain records for this LLM (a property in the LLM contract)."""
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a sampled completion of *prompt*.

        Raises:
            ValueError: if stop sequences are supplied (not supported).
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            output = model.generate(
                **inputs,
                # Unlike max_length, max_new_tokens does not count the prompt,
                # so long prompts still leave room for an answer.
                max_new_tokens=self.max_new_tokens,
                num_return_sequences=1,
                do_sample=True,
                top_k=50,
            )
        # Decode only the continuation; the raw output also contains the prompt.
        prompt_len = inputs["input_ids"].shape[-1]
        generated_text = tokenizer.decode(output[0, prompt_len:], skip_special_tokens=True)
        # Keep only the text after the last "->:" answer marker, if one was emitted.
        return generated_text.split("->:")[-1].strip()

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"n": self.n}
class DbHandler:
    """Thin wrapper around one MySQL connection for the chatbot_conversation table."""

    def __init__(self):
        """Open a connection and a cursor.

        Settings are read from the DB_HOST, DB_USER, DB_PASSWORD, DB_PORT and
        DB_NAME environment variables, falling back to the defaults below.
        """
        import os

        # SECURITY: these defaults are live credentials committed to source.
        # They remain only so existing deployments keep working. Rotate the
        # password, set the environment variables, then delete the defaults.
        self.db_con = mysql.connector.connect(
            host=os.environ.get("DB_HOST", "frwahxxknm9kwy6c.cbetxkdyhwsb.us-east-1.rds.amazonaws.com"),
            user=os.environ.get("DB_USER", "j6qbx3bgjysst4jr"),
            password=os.environ.get("DB_PASSWORD", "mcbsdk2s27ldf37t"),
            port=int(os.environ.get("DB_PORT", "3306")),
            database=os.environ.get("DB_NAME", "nkw2tiuvgv6ufu1z"),
        )
        self.cursorObject = self.db_con.cursor()

    @staticmethod
    def _build_insert_query(fields):
        """Return a parameterized INSERT statement for the column names in *fields*.

        Column names cannot be bound as query parameters, so each one must be
        a plain identifier. Anything else raises ValueError, which keeps
        arbitrary text out of the SQL.
        """
        fields = list(fields)
        if not fields:
            raise ValueError("insert() needs at least one field")
        bad = [f for f in fields if not (isinstance(f, str) and f.isidentifier())]
        if bad:
            raise ValueError(f"invalid column name(s): {bad!r}")
        columns = ", ".join(fields)
        placeholders = ", ".join(["%s"] * len(fields))
        return f"INSERT INTO chatbot_conversation ({columns}) VALUES ({placeholders})"

    def insert(self, fields, values):
        """Insert one row and commit.

        Returns:
            True on success; False (after printing the error) on any failure.
        """
        try:
            fields = list(fields)
            values = tuple(values)
            if len(fields) != len(values):
                raise ValueError(f"{len(fields)} fields but {len(values)} values")
            query = self._build_insert_query(fields)
            # The driver binds the values, so quotes or other SQL metacharacters
            # in the text are stored verbatim instead of breaking the statement.
            self.cursorObject.execute(query, values)
            self.db_con.commit()
            return True
        except Exception as e:
            print(e)
            return False

    def get_history(self, patient_id):
        """Return every row for *patient_id*, ordered by timestamp ascending.

        Returns:
            The list of fetched rows, or None (after printing the error) on failure.
        """
        try:
            self.cursorObject.execute(
                "SELECT * FROM chatbot_conversation WHERE patient_id = %s ORDER BY timestamp ASC",
                (patient_id,),
            )
            return self.cursorObject.fetchall()
        except Exception as e:
            print(e)
            return None

    def close_db(self):
        """Close the underlying connection."""
        self.db_con.close()
def get_conversation_history(db, patient_id):
    """Return column 5 of the newest history row for *patient_id*, or "" if none.

    ``db.get_history`` is expected to return a sequence of rows (or a falsy
    value when there are none or on error). For DbHandler, rows are ordered by
    timestamp ascending, so the last row is the most recent.
    """
    rows = db.get_history(patient_id)
    if not rows:
        return ""
    newest = rows[-1]
    # NOTE(review): index 5 presumably holds summarized_text; confirm against the table schema.
    return newest[5]
# Module-level LangChain wrapper around the local model.
llm = CustomLLM(n=10)

app = FastAPI()

# NOTE(review): a wildcard origin combined with allow_credentials=True may let
# any website make credentialed requests to this API; restrict allow_origins to
# the real front-end origin(s) once known.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
async def root():
    """Health-check handler: report that the service is running."""
    # NOTE(review): no route decorator is visible for this handler; confirm it is registered on `app`.
    status = {"status": "running"}
    return status
def chatbot(patient_id, user_data: dict = None):
    """Answer one patient message with the medical LLM.

    The value returned by get_conversation_history for this patient, when
    non-empty, is included in the prompt as context.

    Args:
        patient_id: Patient identifier used to look up stored history.
        user_data: Request payload; the message is read from
            ``user_data["userObject"]["userInput"]["message"]``.

    Returns:
        ``{"response": <generated answer>}``.

    Raises:
        HTTPException: 400 when the payload carries no non-empty message string.
    """
    from fastapi import HTTPException

    try:
        user_input = user_data["userObject"]["userInput"].get("message")
    except (TypeError, KeyError, AttributeError):
        user_input = None
    if not isinstance(user_input, str) or not user_input.strip():
        raise HTTPException(status_code=400, detail="userObject.userInput.message is required")

    # Only the history lookup needs the database. Close the connection before
    # the slow generation step instead of holding it for the whole request.
    db = DbHandler()
    try:
        history = get_conversation_history(db, patient_id)
    finally:
        db.close_db()

    prompt = (
        "You are now a medical chatbot, and I am a patient. I will describe my "
        "conditions and symptoms and you will give me medical suggestions. "
    )
    if history:
        human_input = (
            f"{prompt}The following is the patient's previous conversation with you: "
            f"{history} This is the current question: {user_input} ->:"
        )
    else:
        human_input = f"{prompt}{user_input} ->:"

    # TODO: persistence is disabled. The exchange is never written back to
    # chatbot_conversation (an earlier version summarized it with
    # ConversationSummaryBufferMemory and called db.insert), so the stored
    # history never grows.
    response = ask_bot(human_input)
    return {"response": response}
if __name__ == '__main__':
    # Imported here so uvicorn is only required when this file is run directly.
    import uvicorn

    # NOTE(review): 'main:app' assumes this file is importable as module `main`;
    # confirm the module name, and whether reload=True is wanted outside development.
    uvicorn.run('main:app', reload=True)