Muhammadidrees committed · verified
Commit 8aceec0 · Parent: 1b79766

Create app.py

Files changed (1): app.py +182 −0
app.py ADDED
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
import time

# =======================================================
# Session state to track multi-step questions
# =======================================================
session_answers = {}
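# NOTE: session_answers is threaded through generate_doctor_response() below
# but never read or written there; it looks like a placeholder for future
# multi-step question tracking (e.g. session_answers["fever_duration"] = "2 days").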

# =======================================================
# Load Model
# =======================================================
model_name = "augtoma/qCammel-13"

print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(model_name)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    low_cpu_mem_usage=True
)
model.eval()

print("Model loaded successfully!")
print(f"Device map: {model.hf_device_map}")
print(f"Model device: {next(model.parameters()).device}")
if torch.cuda.is_available():
    # Guarded so the startup log doesn't crash on CPU-only hardware
    print(f"GPU Memory: {torch.cuda.memory_allocated()/1024**3:.2f} GB")

# =======================================================
# Generate a full response, then stream it back in chunks
# =======================================================
def generate_doctor_response(history, session_answers):
    user_message = history[-1]["content"]

    if not user_message.strip():
        history.append({"role": "assistant", "content": "⚠️ Please describe your symptoms or ask a question."})
        yield history
        return

    # Build conversation prompt
    prompt = """You are an experienced doctor conducting a medical consultation. Your role is to:
1. Ask one follow-up question at a time
2. Provide advice or suggestions if possible
3. Be conversational, caring, and thorough\n\n"""

    # Include the last 5 exchanges (10 messages), excluding the current one
    recent_history = history[-11:-1] if len(history) > 11 else history[:-1]
    for msg in recent_history:
        role = "Patient" if msg["role"] == "user" else "Doctor"
        # Strip the disclaimer appended to earlier assistant turns so it is
        # not fed back into the prompt
        content = msg["content"].replace(
            "⚕️ *Note: This is AI-generated information and not a substitute for professional medical advice. Please consult a healthcare provider for proper diagnosis and treatment.*",
            ""
        ).strip()
        prompt += f"{role}: {content}\n"

    prompt += f"Patient: {user_message}\nDoctor:"

    # Tokenize
    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(model.device)

    gen_config = GenerationConfig(
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        max_new_tokens=120,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.2
    )
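    # Sampling setup: temperature/top_p control randomness, while
    # repetition_penalty > 1.0 discourages the model from looping on phrases.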

    input_length = inputs["input_ids"].shape[1]
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            generation_config=gen_config
        )

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    # Decode only the newly generated tokens, skipping the prompt
    generated_ids = output_ids[0][input_length:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    # Truncate at the first hint of a hallucinated patient turn
    stop_patterns = [
        "Patient:", "\nPatient", "P:", "How are you", "I am feeling", "Thanks"
    ]
    min_stop_pos = len(response)
    for pattern in stop_patterns:
        pos = response.lower().find(pattern.lower())
        if pos != -1 and pos < min_stop_pos:
            min_stop_pos = pos
    response = response[:min_stop_pos].strip()

    # Drop a leading "Doctor:" prefix if the model echoed it
    if response.lower().startswith("doctor:"):
        response = response[len("doctor:"):].strip()

    # Fall back to a generic follow-up if the cleaned response is too short
    if len(response) < 10:
        response = "I understand your concern. Could you please provide more details about your symptoms so I can assist you better?"

    # Append an empty assistant message to fill in while streaming
    history.append({"role": "assistant", "content": ""})

    # Simulated streaming: reveal the finished response 4 characters at a time
    for i in range(0, len(response), 4):
        chunk = response[:i+4]
        history[-1]["content"] = chunk + "▌"
        yield history.copy()
        time.sleep(0.015)

    # Final update without the typing cursor
    history[-1]["content"] = response
    yield history
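    # For true token-level streaming, one would typically replace the loop
    # above with transformers.TextIteratorStreamer and run generate() in a
    # worker thread. A sketch under that assumption (not part of this commit):
    #   from threading import Thread
    #   from transformers import TextIteratorStreamer
    #   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    #   Thread(target=model.generate, kwargs={**inputs, "generation_config": gen_config, "streamer": streamer}).start()
    #   partial = ""
    #   for new_text in streamer:
    #       partial += new_text
    #       history[-1]["content"] = partial + "▌"
    #       yield history.copy()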

# =======================================================
# Gradio Interface
# =======================================================
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🩺 AI Doctor Chat Assistant")

    chatbot = gr.Chatbot(
        label="💬 Doctor Consultation",
        type="messages",
        avatar_images=(
            "https://cdn-icons-png.flaticon.com/512/706/706830.png",   # Patient
            "https://cdn-icons-png.flaticon.com/512/3774/3774299.png"  # Doctor
        ),
        height=500
    )

    with gr.Row():
        user_input = gr.Textbox(
            placeholder="Type your symptoms or question here...",
            label="🧍 Your Message",
            lines=2,
            scale=4
        )

    with gr.Row():
        send_btn = gr.Button("💬 Send", variant="primary", scale=1)
        clear_btn = gr.Button("🧹 Clear Chat", scale=1)

    gr.Examples(
        examples=[
            "I have a fever of 102°F since yesterday",
            "I've been having headaches for the past week",
            "I feel very tired all the time",
            "I have a sore throat and body aches",
        ],
        inputs=user_input,
        label="💡 Example Questions"
    )

    # Response function: appends the user turn, then relays streamed updates
    def respond(message, history):
        global session_answers
        if history is None:
            history = []
        if not message.strip():
            # This function is a generator, so a bare `return value` would be
            # swallowed; yield the unchanged state instead.
            yield "", history
            return
        history.append({"role": "user", "content": message})
        for updated_history in generate_doctor_response(history, session_answers):
            yield "", updated_history

    # Event handlers
    send_btn.click(respond, [user_input, chatbot], [user_input, chatbot])
    user_input.submit(respond, [user_input, chatbot], [user_input, chatbot])
    clear_btn.click(lambda: [], None, chatbot, queue=False)

# Launch
if __name__ == "__main__":
    demo.queue()
    demo.launch(share=True)
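
# To run this app locally (assumed environment; no requirements file is part
# of this commit):
#   pip install gradio transformers torch accelerate
#   python app.py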