Mr-HASSAN committed on
Commit
f06c4ab
·
verified ·
1 Parent(s): c7d192d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -277
app.py CHANGED
@@ -1,286 +1,106 @@
1
- import gradio as gr
2
- import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
4
- import re
5
-
6
class HuatuoMedicalAI:
    """Stateful medical consultation agent backed by HuatuoGPT-7B.

    Flow: collect patient messages, ask up to ``max_questions`` clarifying
    questions, then emit a final assessment built from the accumulated
    conversation history and reset for the next consultation.
    """

    def __init__(self):
        # Model and tokenizer are lazy-loaded on first use (see load_model).
        self.model = None
        self.tokenizer = None
        self.loaded = False
        # Running transcript of "Patient: ..." / "Doctor: ..." lines.
        self.conversation_history = []
        # Incremented once per patient message; gates the final report.
        self.question_count = 0
        self.max_questions = 3

    def load_model(self):
        """Load the HuatuoGPT-7B model and tokenizer.

        Returns True on success (or if already loaded), False on failure.
        """
        if self.loaded:
            return True

        try:
            print("🔄 Loading HuatuoGPT-7B Medical AI...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                "FreedomIntelligence/HuatuoGPT-7B",
                trust_remote_code=True
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                "FreedomIntelligence/HuatuoGPT-7B",
                torch_dtype=torch.float16,  # half precision for the 7B weights
                device_map="auto",          # let accelerate place layers on devices
                trust_remote_code=True
            )
            self.loaded = True
            print("✅ HuatuoGPT-7B Medical AI Ready!")
            return True
        except Exception as e:
            # Loading can fail for many reasons (network, VRAM); report and
            # signal failure so the caller can show a "loading" message.
            print(f"❌ Error loading HuatuoGPT: {e}")
            return False

    def needs_clarification(self, message):
        """Heuristically decide whether a patient message is too vague.

        Returns True for very short messages, or 1-2 word messages that
        contain a known vague term.
        """
        message = message.strip().lower()

        # Very short or unclear messages need clarification
        if len(message) < 3:
            return True

        if len(message.split()) <= 2 and len(message) < 15:
            vague_terms = ['help', 'pain', 'sick', 'problem', 'issue', '?']
            if any(term in message for term in vague_terms):
                return True

        return False

    def generate_medical_question(self, user_message):
        """Generate a single-line clarifying question via HuatuoGPT.

        Falls back to a canned question if generation fails.
        """
        try:
            # Medical consultation prompt for HuatuoGPT
            prompt = f"""As a medical doctor, the patient says: "{user_message}"

What is the most important single question to ask for better diagnosis?
Keep it to one line only.

Question:"""

            inputs = self.tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)

            with torch.no_grad():
                outputs = self.model.generate(
                    inputs.input_ids,
                    max_new_tokens=25,  # Short for 1-line questions
                    temperature=0.7,
                    do_sample=True,
                    top_p=0.9,
                    repetition_penalty=1.1,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            # The decoded output echoes the prompt; keep only what follows
            # the final "Question:" marker.
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            question = response.split("Question:")[-1].strip()

            # Clean and ensure 1-line format
            question = re.sub(r'\n+', ' ', question)  # Remove newlines
            question = question.split('.')[0]  # Take first sentence only
            if not question.endswith('?'):
                question += '?'

            # Force single line and reasonable length (max 12 words).
            question = question.replace('\n', ' ').strip()
            if len(question.split()) > 12:
                question = ' '.join(question.split()[:12]) + '?'

            return question

        except Exception as e:
            # Generation failed — pick a generic fallback question so the
            # consultation can continue.
            fallback_questions = [
                "How long have you had these symptoms?",
                "Where exactly is the pain located?",
                "Can you rate the severity from 1-10?",
                "Any other symptoms you're experiencing?"
            ]
            import random
            return random.choice(fallback_questions)

    def get_final_medical_report(self):
        """Generate the final medical assessment from the full transcript.

        Returns a generic consult-a-professional message on any failure.
        """
        context = "\n".join(self.conversation_history)

        prompt = f"""Based on this medical conversation:

{context}

As Doctor HuatuoGPT, provide a comprehensive medical assessment including:
1. Possible diagnosis based on symptoms
2. Immediate self-care recommendations
3. When to seek urgent medical attention
4. General health advice

Keep the response professional, clear, and helpful.

Medical Assessment:"""

        try:
            inputs = self.tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)

            with torch.no_grad():
                outputs = self.model.generate(
                    inputs.input_ids,
                    max_new_tokens=400,   # longer budget than questions
                    temperature=0.3,      # lower temperature for a focused report
                    do_sample=True,
                    top_p=0.8,
                    repetition_penalty=1.1,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            assessment = response.split("Medical Assessment:")[-1].strip()

            # Add standard medical disclaimer
            assessment += "\n\n⚠️ **Medical Disclaimer**: This is AI-generated advice. Please consult a healthcare professional for proper diagnosis and treatment."

            return assessment

        except Exception as e:
            return "Based on our conversation, I recommend consulting a healthcare professional for proper medical evaluation and treatment."

    def process_patient_message(self, message):
        """Handle one patient message; main entry point.

        Returns a (response_text, question_count, max_questions) tuple.
        Either asks a clarifying question, or produces the final report and
        resets conversation state.
        """
        if not self.loaded:
            success = self.load_model()
            if not success:
                return "🔄 Medical AI is loading... Please wait.", 0, 3

        # Add patient message to conversation
        self.conversation_history.append(f"Patient: {message}")
        self.question_count += 1

        # Ask another clarifying question only while the message is vague
        # AND we are still under the question budget.
        if self.needs_clarification(message) and self.question_count < self.max_questions:
            question = self.generate_medical_question(message)
            self.conversation_history.append(f"Doctor: {question}")
            return question, self.question_count, self.max_questions
        else:
            # Generate final medical report
            final_report = self.get_final_medical_report()
            self.conversation_history.append(f"Doctor: {final_report}")

            # Reset for next conversation
            self.question_count = 0
            self.conversation_history = []

            return final_report, 0, self.max_questions
174
-
175
# Module-level singleton agent shared by all UI sessions (state is global,
# so concurrent users would share one conversation).
medical_ai = HuatuoMedicalAI()
177
-
178
def chat_interface(message, chat_history, question_count):
    """Run one Gradio chat turn through the shared medical agent.

    Returns (cleared_input, updated_history, updated_question_count).
    Blank or whitespace-only input is a no-op apart from clearing the box.
    """
    if not message.strip():
        # Nothing to process; leave history and counter untouched.
        return "", chat_history, question_count

    reply, updated_count, _ = medical_ai.process_patient_message(message)
    chat_history.append((message, reply))
    return "", chat_history, updated_count
190
-
191
def clear_conversation():
    """Start a fresh consultation.

    Wipes the shared agent's transcript and counter, and returns the
    values that reset the chatbot widget and hidden count state.
    """
    # Order of the two resets is irrelevant; both live on the shared agent.
    medical_ai.question_count = 0
    medical_ai.conversation_history = []
    return [], 0
196
-
197
# Gradio UI. NOTE: gr.Blocks is constructed without a `theme` argument
# (it was removed for compatibility with the installed Gradio version).
with gr.Blocks(title="HuatuoGPT Medical Assistant") as demo:
    gr.Markdown("""
    # 🩺 HuatuoGPT Medical Assistant
    **AI-Powered Medical Consultation • 3 Questions Max • Professional Medical Advice**

    *Patient describes symptoms → AI asks clarifying questions → Final medical assessment*
    """)

    # Read-only "N/3" display of clarifying questions asked so far.
    question_counter = gr.Textbox(
        label="Clarification Questions",
        value="0/3",
        interactive=False,
        max_lines=1
    )

    # Main conversation transcript.
    chatbot = gr.Chatbot(
        label="Medical Consultation",
        height=500
    )

    with gr.Row():
        # Patient input - can be any length
        msg = gr.Textbox(
            label="Describe Your Symptoms",
            placeholder="Example: headache for 2 days with sensitivity to light...",
            lines=3,
            scale=4
        )
        send_btn = gr.Button("🚀 Send to Doctor", scale=1, variant="primary")

    with gr.Row():
        clear_btn = gr.Button("🔄 New Consultation")
        status = gr.Textbox(
            label="Status",
            value="HuatuoGPT-7B Medical AI - Ready for Consultation",
            interactive=False,
            max_lines=2
        )

    # Hidden state holding the current question count between events.
    current_count = gr.State(0)

    def update_counter(question_count, max_questions=3):
        """Format the counter state as the "N/3" display string."""
        return f"{question_count}/{max_questions}"

    def respond(message, chat_history, question_count):
        """Event handler for send/submit: run one turn through the agent."""
        if not message.strip():
            # Empty input: echo state back unchanged (input box keeps text).
            return message, chat_history, question_count

        response, new_count, max_questions = medical_ai.process_patient_message(message)
        chat_history.append((message, response))

        return "", chat_history, new_count

    # Wire events: both Enter-submit and the send button run `respond`,
    # then refresh the counter display from the updated hidden state.
    msg.submit(
        respond,
        [msg, chatbot, current_count],
        [msg, chatbot, current_count]
    ).then(
        update_counter,
        [current_count],
        [question_counter]
    )

    send_btn.click(
        respond,
        [msg, chatbot, current_count],
        [msg, chatbot, current_count]
    ).then(
        update_counter,
        [current_count],
        [question_counter]
    )

    # New-consultation button: reset agent state, then clear the input box
    # and counter display.
    clear_btn.click(
        clear_conversation,
        outputs=[chatbot, current_count]
    ).then(
        lambda: ("", "0/3"),
        outputs=[msg, question_counter]
    )
284
 
285
if __name__ == "__main__":
    # share=False: serve locally only, no public Gradio tunnel.
    demo.launch(share=False)
 
1
# app.py — Flask backend bridging sign-language patients and doctors.
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
import base64
import tempfile
import os

app = Flask(__name__)
CORS(app)  # allow the browser frontend to call the API cross-origin

# Initialize shared pipeline components (one instance each, module-level).
# NOTE(review): MedicalAgent, SignLanguageTranslator, SignLanguageGenerator,
# and SpeechProcessor are referenced here but never imported in this file —
# as written this raises NameError at startup. Confirm the modules that
# define them and add the imports.
medical_agent = MedicalAgent()
sign_translator = SignLanguageTranslator()
sign_generator = SignLanguageGenerator()
speech_processor = SpeechProcessor()

# Seed the agent's retrieval store with sample headache knowledge.
medical_knowledge = [
    "Headache can be caused by tension, migraine, or sinus issues",
    "Common headache symptoms include throbbing pain, sensitivity to light",
    "Headache duration and location help diagnose the type",
    "Migraine often includes nausea and light sensitivity",
    "Tension headaches typically cause band-like pressure around head"
]
medical_agent.rag.add_medical_knowledge(medical_knowledge)
26
+
27
@app.route('/')
def index():
    """Serve the single-page UI from templates/index.html."""
    return render_template('index.html')
30
+
31
@app.route('/api/process_sign_language', methods=['POST'])
def process_sign_language():
    """Translate a sign-language video frame and return the agent's reply.

    Expects JSON: {"video_data": <base64-encoded video frame>}.
    Responds with either a clarifying question (plus its sign animation and
    question count) or a final summary (plus TTS audio for the doctor).
    Any failure — including a missing 'video_data' key — returns
    {"error": ...} with HTTP 500.
    """
    try:
        video_data = request.json['video_data']  # Base64 encoded video frame
        # NOTE(review): decode_video_frame is currently a stub (`pass`), so
        # `frame` is None here until it is implemented.
        frame = decode_video_frame(video_data)

        # Sign language -> patient text.
        patient_text = sign_translator.process_video_frame(frame)

        # Run the patient text through the medical agent.
        agent_response = medical_agent.process_patient_input(patient_text)

        if agent_response['type'] == 'question':
            # Agent wants clarification: render the question as sign animation.
            sign_animation = sign_generator.text_to_sign_animation(
                agent_response['content']
            )
            return jsonify({
                'type': 'question',
                'text': agent_response['content'],
                'sign_animation': sign_animation,
                'question_count': agent_response['question_count']
            })
        else:
            # Consultation finished: synthesize the summary for the doctor.
            tts_audio = speech_processor.text_to_speech(
                agent_response['content'],
                "summary.wav"
            )
            return jsonify({
                'type': 'summary',
                'text': agent_response['content'],
                'audio': tts_audio
            })

    except Exception as e:
        return jsonify({'error': str(e)}), 500
69
+
70
@app.route('/api/process_doctor_audio', methods=['POST'])
def process_doctor_audio():
    """Transcribe a doctor's audio question and render it in sign language.

    Expects JSON: {"audio_data": <base64-encoded WAV audio>}.
    Pipeline: save audio to a temp file -> speech-to-text -> agent rephrases
    the question for the patient -> sign-language animation.
    Responds with {"question": ..., "sign_animation": ...}; any failure —
    including a missing 'audio_data' key — returns {"error": ...} with
    HTTP 500.
    """
    try:
        audio_data = request.json['audio_data']

        # Persist the decoded audio so the speech processor can read a path.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as f:
            f.write(base64.b64decode(audio_data))
            audio_path = f.name

        # FIX: the temp file was only unlinked on the success path, leaking
        # a file per failed request; finally guarantees cleanup either way.
        try:
            # Speech -> text.
            doctor_text = speech_processor.speech_to_text(audio_path)

            # Let the agent turn the doctor's words into a patient question.
            patient_question = medical_agent.process_doctor_question(doctor_text)

            # Text -> sign-language animation.
            sign_animation = sign_generator.text_to_sign_animation(patient_question)
        finally:
            os.unlink(audio_path)  # Clean up the temp file unconditionally

        return jsonify({
            'question': patient_question,
            'sign_animation': sign_animation
        })

    except Exception as e:
        return jsonify({'error': str(e)}), 500
99
+
100
def decode_video_frame(video_data):
    """Decode a base64-encoded video frame into raw bytes.

    Accepts either a bare base64 string or a browser-style data URL
    ("data:image/jpeg;base64,....") as produced by canvas.toDataURL().

    Args:
        video_data: Base64 payload (str or bytes), optionally prefixed with
            a "data:<mime>;base64," header.

    Returns:
        The decoded frame bytes, or None for empty input (preserving the
        previous stub's "no frame" result).

    Raises:
        binascii.Error: if the payload is not valid base64.

    NOTE(review): sign_translator.process_video_frame() may expect a decoded
    image array rather than raw encoded bytes — confirm the frame format the
    frontend sends and the translator consumes.
    """
    if not video_data:
        return None
    # Strip a data-URL header if present; the payload follows the comma.
    if isinstance(video_data, str) and video_data.lstrip().startswith('data:'):
        video_data = video_data.split(',', 1)[-1]
    return base64.b64decode(video_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
if __name__ == '__main__':
    # NOTE(review): debug=True enables the Werkzeug interactive debugger
    # (arbitrary code execution on exception pages) and host='0.0.0.0'
    # exposes it on all interfaces — disable debug for any non-local use.
    app.run(host='0.0.0.0', port=5000, debug=True)