duniele committed on
Commit 48dda83 · verified · 1 Parent(s): 8de173f

Update app.py

Files changed (1):
  1. app.py +107 -54
app.py CHANGED
@@ -1,70 +1,123 @@
 
 
  import gradio as gr
- from huggingface_hub import InferenceClient
-
-
- def respond(
-     message,
-     history: list[dict[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
-     hf_token: gr.OAuthToken,
- ):
-     """
-     For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-     """
-     client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
-
-     messages = [{"role": "system", "content": system_message}]
-
-     messages.extend(history)
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         choices = message.choices
-         token = ""
-         if len(choices) and choices[0].delta.content:
-             token = choices[0].delta.content
-
-         response += token
-         yield response
-
-
  """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- chatbot = gr.ChatInterface(
-     respond,
-     type="messages",
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-
- with gr.Blocks() as demo:
-     with gr.Sidebar():
-         gr.LoginButton()
-     chatbot.render()
-
-
  if __name__ == "__main__":
-     demo.launch()
+ import os
+ import torch
  import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+ from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
+ from langchain_chroma import Chroma
+ from typing import Dict, Any, List
+
+ # --- 1. SETUP & MODEL LOADING ---
+ print("⏳ Loading Models...")
+
+ # Initialize Embeddings (CPU is fine for this)
+ embedding_function = HuggingFaceEmbeddings(
+     model_name="nomic-ai/nomic-embed-text-v1.5",
+     model_kwargs={"trust_remote_code": True, "device": "cpu"}
+ )
+
+ # Load Vector Database
+ # NOTE: Ensure the 'chroma_db' folder is uploaded to the same directory as app.py
+ if not os.path.exists("./chroma_db"):
+     raise ValueError("❌ Error: 'chroma_db' folder not found! Please upload your vector database.")
+
+ vector_db = Chroma(persist_directory="./chroma_db", embedding_function=embedding_function)
+
+ # Load LLM (TinyLlama)
+ # We use device_map="auto" to use GPU if available in the Space, otherwise CPU
+ model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
+
+ # Create HF Pipeline
+ pipe = pipeline(
+     "text-generation",
+     model=model,
+     tokenizer=tokenizer,
+     max_new_tokens=256,
+     repetition_penalty=1.15,
+     temperature=0.1,
+     do_sample=True
+ )
+
+ llm = HuggingFacePipeline(pipeline=pipe)
+
+ # --- 2. DEFINE MANUAL QA CHAIN ---
+ class ManualQAChain:
+     def __init__(self, vector_store: Chroma, llm_pipeline: HuggingFacePipeline):
+         self.retriever = vector_store.as_retriever(search_kwargs={"k": 2})
+         self.llm = llm_pipeline
+
+     def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
+         query = inputs.get("query", "")
+
+         # 1. RETRIEVAL
+         docs = self.retriever.invoke(query)
+         context = "\n\n".join([d.page_content for d in docs])
+
+         # 2. PROMPT CREATION
+         max_context_length = 2000
+         prompt = f"""<|system|>
+ You are a helpful and accurate medical assistant.
+ Use ONLY the following context to answer the user's question.
+ If the context does not contain the answer, say: "I cannot find the answer in the provided context."
+
+ Context:
+ {context[:max_context_length]}
+ </s>
+ <|user|>
+ {query}
+ </s>
+ <|assistant|>
  """
+         # 3. GENERATION
+         response = self.llm.invoke(prompt)
+
+         # Handle Output format (some versions return list, some string)
+         text = response[0]['generated_text'] if isinstance(response, list) else str(response)
+
+         # Clean output
+         if "<|assistant|>" in text:
+             final_answer = text.split("<|assistant|>")[-1].strip()
+         else:
+             final_answer = text.strip()
+
+         return {"result": final_answer, "source_documents": docs}
+
+ # Initialize Chain
+ qa_chain = ManualQAChain(vector_db, llm)
+ print("✅ RAG Pipeline is ready.")
+
+ # --- 3. GRADIO UI FUNCTION ---
+ def medical_rag_chat(message, history):
+     if not message:
+         return "Please ask a medical question."
+     try:
+         response = qa_chain.invoke({"query": message})
+         answer_text = response['result']
+
+         # Format Sources
+         sources_text = "\n\n---\n**Retrieved Context:**\n"
+         if response.get('source_documents'):
+             for i, doc in enumerate(response['source_documents']):
+                 topic = doc.metadata.get('focus_area', 'Medical Protocol')
+                 snippet = doc.page_content.replace('\n', ' ').strip()
+                 sources_text += f"**{i+1}. [{topic}]** *\"{snippet[:500]}...\"*\n"
+         else:
+             sources_text += "(No context found.)"
+
+         return answer_text + sources_text
+     except Exception as e:
+         return f"⚠️ Error: {str(e)}"
+
+ # --- 4. LAUNCH UI ---
+ # Note: share=True is NOT needed in HF Spaces
+ demo = gr.ChatInterface(
+     fn=medical_rag_chat,
+     title="Cardio-Oncology RAG Assistant",
+     description="TinyLlama-1.1B + MedQuAD RAG",
+     examples=["What is (are) BRCA2 hereditary breast and ovarian cancer syndrome ?", "Who is at risk for Heart Failure?"],
+     concurrency_limit=2
+ )
+
  if __name__ == "__main__":
+     demo.launch()
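
The updated app.py assumes a prebuilt ./chroma_db folder sits next to it; the indexing step itself is not part of this commit. Below is a minimal offline indexing sketch, assuming the MedQuAD data is available as a CSV with question, answer, and focus_area columns (the file name and the first two column names are hypothetical; focus_area is the one field app.py actually reads from document metadata). The embedding model must match the one loaded in app.py, otherwise query vectors will not line up with the index.

# build_index.py (hypothetical helper, not part of this commit)
import pandas as pd
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings

# Must be the same embedding model that app.py loads at startup
embedding_function = HuggingFaceEmbeddings(
    model_name="nomic-ai/nomic-embed-text-v1.5",
    model_kwargs={"trust_remote_code": True, "device": "cpu"},
)

# Assumed MedQuAD export; adjust the path and column names to your data
df = pd.read_csv("medquad.csv")

docs = [
    Document(
        page_content=f"Question: {row.question}\nAnswer: {row.answer}",
        metadata={"focus_area": row.focus_area},  # read by app.py when formatting sources
    )
    for row in df.itertuples()
]

# Writes the ./chroma_db folder that app.py checks for at startup
Chroma.from_documents(
    documents=docs,
    embedding=embedding_function,
    persist_directory="./chroma_db",
)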
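
Before pushing, the chain can also be exercised locally with a short smoke test. Importing app runs its module-level setup (embeddings, vector DB, TinyLlama, chain) but not demo.launch(), which sits behind the __main__ guard. A sketch, assuming ./chroma_db exists locally and the machine can hold TinyLlama in memory:

# smoke_test.py (hypothetical, not part of this commit)
from app import qa_chain  # triggers model and vector DB loading

out = qa_chain.invoke({"query": "Who is at risk for Heart Failure?"})
print(out["result"])

# Verify retrieval returned real context from the vector store
for doc in out["source_documents"]:
    print("-", doc.metadata.get("focus_area", "n/a"), "|", doc.page_content[:80])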