suusuu93 committed on
Commit
6924870
·
verified ·
1 Parent(s): 685ff14

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -21
app.py CHANGED
@@ -1,22 +1,142 @@
 
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
- import torch
4
-
5
- # Tải hình chỉ 1 lần duy nhất
6
- tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
7
- model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
8
-
9
- # Hàm chat đơn giản
10
- def chatbot(msg):
11
- input_ids = tokenizer.encode(msg + tokenizer.eos_token, return_tensors='pt')
12
- output_ids = model.generate(input_ids, max_length=100, pad_token_id=tokenizer.eos_token_id)
13
- reply = tokenizer.decode(output_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
14
- return f"🐱 Dr. Ask: {reply}"
15
-
16
- demo = gr.Interface(fn=chatbot,
17
- inputs=gr.Textbox(label="Bạn hỏi gì nè"),
18
- outputs=gr.Textbox(label="Dr. Ask AI trả lời"),
19
- title="🐱 Dr. Ask AI",
20
- theme="default")
21
-
22
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
  import gradio as gr
3
+
4
+ from langchain_groq import ChatGroq
5
+ from langchain_core.prompts import ChatPromptTemplate
6
+ from langchain_core.output_parsers import StrOutputParser
7
+ from langchain_core.runnables import RunnablePassthrough
8
+
9
+ from langchain_community.embeddings import HuggingFaceEmbeddings
10
+ from langchain_community.document_loaders import PyPDFLoader
11
+ from langchain_text_splitters.sentence_transformers import SentenceTransformersTokenTextSplitter
12
+ from langchain_chroma import Chroma
13
+
14
+
15
+ # ==============================
16
+ # CONFIG
17
+ # ==============================
18
# Write HF_TOKEN back into the environment so the key is always defined
# (it becomes an empty string when the variable was not provided).
os.environ["HF_TOKEN"] = os.environ.get("HF_TOKEN", "")

# The Groq key is mandatory: fail at startup instead of on the first query.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise ValueError("GROQ_API_KEY not found in environment variables")

# PDF to index, and the directory used as the Chroma persist_directory.
DATASET_PATH, PERSIST_DIR = "dataset.pdf", "pharma_db"

# exist_ok=True makes this a no-op when the directory is already present.
os.makedirs(PERSIST_DIR, exist_ok=True)
29
+
30
+
31
+ # ==============================
32
+ # EMBEDDINGS (FASTER MODEL)
33
+ # ==============================
34
# Embedding model handed to the Chroma store below as its embedding function.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")


# ==============================
# VECTOR DB
# ==============================
# Chroma collection persisted under PERSIST_DIR.
db = Chroma(embedding_function=embeddings, persist_directory=PERSIST_DIR)
46
+
47
+
48
+ # ==============================
49
+ # LOAD & INDEX PDF
50
+ # ==============================
51
# Index the PDF into the vector store once: skipped when the PDF is missing
# or when the persisted collection already holds at least one entry.
if not os.path.exists(DATASET_PATH):
    print("⚠️ PDF not found in repo.")
elif not db.get(limit=1)["ids"]:
    # limit=1 is enough to tell "empty" from "non-empty" without loading
    # every stored document just to count ids.
    print("Indexing PDF...")

    documents = PyPDFLoader(DATASET_PATH).load()

    # Tokenise with the SAME model that produces the embeddings so that the
    # chunk length is measured in that model's tokens. The chunk length of
    # this splitter is governed by `tokens_per_chunk` (left unset here so it
    # falls back to the model's own maximum sequence length), not by the
    # character-based `chunk_size` of the base splitter.
    # NOTE(review): chunks created by an earlier run with the previous
    # settings stay in PERSIST_DIR until that directory is cleared.
    splitter = SentenceTransformersTokenTextSplitter(
        model_name=embeddings.model_name,
        chunk_overlap=50,
    )

    chunks = splitter.split_documents(documents)
    db.add_documents(chunks)

    print("✅ PDF indexed.")
72
+
73
+
74
+ # ==============================
75
+ # PROMPT
76
+ # ==============================
77
# System instructions for the assistant; `{context}` is filled with the
# retrieved document text and `{question}` (below) with the user's input.
_SYSTEM_MESSAGE = """You are 'Dr MomAI Assistant', a specialized medical AI expert focused on mom and baby.
GUIDELINES:
1. INTERACTIVE GREETINGS: If the user greets you (e.g., "Hi", "Hello", "Who are you?"), respond politely, introduce yourself as Dr Mom AI Assistant, and explain that you are here to help them understand information.
2. CONTEXTUAL ACCURACY: For all medical or factual questions, prioritize the information provided in the 'Context' section below.
3. STRICTNESS: If the question is medical in nature but the answer is NOT found in the context, explicitly state something like this: "I'm sorry, but that specific information is not available in my current medical knowledge."
4. TONE: Maintain a professional, empathetic, and clinical tone. Use bullet points for complex medical explanations to ensure clarity.
Context:
{context}"""

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", _SYSTEM_MESSAGE),
        ("human", "{question}"),
    ]
)

# Converts the chat model's reply message into a plain string.
output_parser = StrOutputParser()
90
+
91
+
92
def format_docs(docs):
    """Merge retrieved documents into a single context string.

    Args:
        docs: Iterable of objects exposing a ``page_content`` string.

    Returns:
        The ``page_content`` values in their original order, separated by
        one blank line. An empty iterable gives an empty string.
    """
    page_texts = [document.page_content for document in docs]
    blank_line = "\n\n"
    return blank_line.join(page_texts)
94
+
95
+
96
+ # ==============================
97
+ # RAG QUERY
98
+ # ==============================
99
# Shared RAG chain, built on the first real query and reused afterwards.
_RAG_CHAIN = None


def _get_rag_chain():
    """Return the retrieval -> prompt -> LLM -> parser chain, building it once."""
    global _RAG_CHAIN
    if _RAG_CHAIN is None:
        # Built lazily so that blank questions never touch the vector store
        # or the Groq client. Two concurrent first calls could each build a
        # chain; the results are equivalent, so the last assignment wins.
        retriever = db.as_retriever(search_kwargs={"k": 5})
        llm = ChatGroq(
            model="llama-3.1-8b-instant",
            api_key=GROQ_API_KEY,
            temperature=0,
        )
        _RAG_CHAIN = (
            {
                # The question is sent to the retriever and the hits are
                # flattened into one text block for the prompt's {context}.
                "context": retriever | format_docs,
                # The same question is forwarded unchanged as {question}.
                "question": RunnablePassthrough(),
            }
            | prompt
            | llm
            | output_parser
        )
    return _RAG_CHAIN


def run_query(question):
    """Answer a user question with the retrieval-augmented chain.

    Args:
        question: Text typed by the user. ``None``, an empty string, or a
            whitespace-only string is treated as "no question".

    Returns:
        ``"Please enter a question."`` when no question was given; otherwise
        the string produced by the chain for ``question``.
    """
    # Guard against None as well as blank text so a missing value yields the
    # same hint instead of an AttributeError on ``.strip()``.
    if question is None or not question.strip():
        return "Please enter a question."

    return _get_rag_chain().invoke(question)
123
+
124
+
125
+ # ==============================
126
+ # GRADIO UI
127
+ # ==============================
128
# Input and output widgets for the single-question form.
question_box = gr.Textbox(label="Question", placeholder="Ask me something...")
response_box = gr.Textbox(label="Response", lines=10)

# Wire the form to run_query and start the Gradio app.
interface = gr.Interface(
    fn=run_query,
    inputs=question_box,
    outputs=response_box,
    title="Your Assistant",
    description="Ask questions",
)

interface.launch()