Spaces:
Sleeping
Sleeping
| """ | |
| API Server for MedLLaMA2 Medical Chatbot | |
| This file provides REST API endpoints that can be used by external applications | |
| while the main app.py provides the Gradio interface. | |
| """ | |
| import os | |
| import threading | |
| from flask import Flask, request, jsonify, Response | |
| from flask_cors import CORS | |
| import json | |
| import time | |
| import re | |
| # Import the model and functions from the main app | |
| from app import load_model, generate_response, get_model_info | |
| from config import GENERATION_DEFAULTS | |
# Flask application object for this API module.
app = Flask(__name__)
# Enable CORS for every route registered on the app.
CORS(app)
# Model initialisation; used as the target of the background thread below.
def init_model():
    """Load the model via ``load_model`` and print progress before and after."""
    starting_msg = "π Loading model in API server..."
    finished_msg = "β Model loaded in API server"

    print(starting_msg)
    load_model()
    print(finished_msg)
# Start loading the model on a separate thread as soon as this module is
# imported; health_check reports whether a model is loaded yet.
model_thread = threading.Thread(target=init_model)
model_thread.start()
def health_check():
    """Health check endpoint.

    Returns:
        A JSON response with ``status`` (always ``'ok'``), ``model_loaded``
        (whether ``get_model_info()`` reports something other than
        ``"No model loaded"``), ``model_info`` and a ``timestamp``.
    """
    # Read the model info once: the model is loaded on a background thread,
    # so two separate calls could disagree and produce a payload whose
    # 'model_loaded' flag contradicts its 'model_info' text.
    model_info = get_model_info()
    return jsonify({
        'status': 'ok',
        'model_loaded': model_info != "No model loaded",
        'model_info': model_info,
        'timestamp': time.time()
    })
def chat_endpoint():
    """Main chat endpoint for medical questions.

    Reads a JSON object from the request body. ``message`` (a string) is
    required; ``max_tokens``, ``temperature`` and ``top_p`` are optional and
    fall back to ``GENERATION_DEFAULTS``.

    Returns:
        * 200 with ``response``, ``timestamp`` and ``model_info`` when a
          reply was generated by ``generate_response``.
        * 200 with a canned refusal (``response`` and ``timestamp``) when the
          message matches a non-medical topic and no medical exception.
        * 400 with ``error`` when the body is not a JSON object, ``message``
          is missing, not a string or blank, or a generation parameter is
          not numeric.
        * 500 with ``error`` and ``details`` when anything else fails.
    """
    try:
        # silent=True turns a malformed or non-JSON body into None so it is
        # rejected as a client error below instead of raising into the
        # generic handler and being reported as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'message' not in data:
            return jsonify({'error': 'No message provided'}), 400

        message = data['message']
        if not isinstance(message, str):
            return jsonify({'error': 'Message must be a string'}), 400
        message = message.strip()
        if not message:
            return jsonify({'error': 'Empty message'}), 400

        # Refuse off-topic questions before doing any model work.
        if _is_non_medical_chat_message(message):
            return jsonify({
                'response': "I'm a medical assistant designed to provide health-related information. I'm not able to help with programming, cooking, or other non-medical topics. If you have any questions about health, medicine, symptoms, or wellness, I'd be happy to assist you! π",
                'timestamp': time.time()
            })

        # Optional generation parameters: a value that cannot be converted is
        # the caller's mistake, so answer 400 rather than 500.
        try:
            max_tokens = int(data.get('max_tokens', GENERATION_DEFAULTS['max_new_tokens']))
            temperature = float(data.get('temperature', GENERATION_DEFAULTS['temperature']))
            top_p = float(data.get('top_p', GENERATION_DEFAULTS['top_p']))
        except (TypeError, ValueError):
            return jsonify({'error': 'Invalid generation parameters'}), 400

        # Generate medical response
        response = generate_response(
            message,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p
        )

        return jsonify({
            'response': response,
            'timestamp': time.time(),
            'model_info': get_model_info()
        })
    except Exception as e:
        # Last-resort boundary so the client always receives JSON.
        # NOTE(review): 'details' exposes the raw exception text to callers;
        # kept for backward compatibility, consider removing it.
        print(f"Error in chat endpoint: {str(e)}")
        return jsonify({
            'error': 'Internal server error',
            'details': str(e)
        }), 500


# Topics the chat endpoint refuses to discuss; compiled once at import time
# instead of being rebuilt on every request.
_CHAT_NON_MEDICAL_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r'\b(java|javascript|python|c\+\+|c#|programming|coding|computer|software)\b',
        r'\b(cook|recipe|food recipe|baking)\b',
        r'\b(math problem|finance|stock market|weather|movie|book|travel)\b',
    )
)

# Phrases that keep a message on-topic even if a pattern above also matches.
_CHAT_MEDICAL_EXCEPTIONS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r'medical (history|coding|program|software|algorithm)',
        r'health (history|software|recipe)',
        r'(food allergy|diet recipe|patient story|medical story)',
    )
)


def _is_non_medical_chat_message(message):
    """Return True when *message* hits a non-medical topic and no medical exception."""
    if not any(p.search(message) for p in _CHAT_NON_MEDICAL_PATTERNS):
        return False
    return not any(p.search(message) for p in _CHAT_MEDICAL_EXCEPTIONS)
def stream_chat():
    """Streaming chat endpoint (server-sent events).

    Reads the same JSON body as the non-streaming chat endpoint: a required
    string ``message`` plus optional ``max_tokens``, ``temperature`` and
    ``top_p`` that fall back to ``GENERATION_DEFAULTS``.

    The full reply is produced by a single ``generate_response`` call and
    then replayed piece by piece:

    * one ``data:`` event per piece with ``chunk`` and ``status='streaming'``;
    * a final ``event: end`` carrying ``complete`` and ``fullResponse``;
    * an ``event: error`` with ``error``/``details`` if generation fails
      after the stream has started.

    Returns:
        A ``text/event-stream`` response, or a JSON error with status 400
        (invalid body, message or parameters) or 500 (unexpected failure
        before streaming starts).
    """
    # NOTE(review): unlike chat_endpoint, this endpoint does not apply the
    # non-medical topic filter - confirm whether that is intentional.
    try:
        # silent=True turns a malformed or non-JSON body into None so it is
        # rejected as a 400 below instead of surfacing as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'message' not in data:
            return jsonify({'error': 'No message provided'}), 400

        message = data['message']
        if not isinstance(message, str):
            return jsonify({'error': 'Message must be a string'}), 400
        message = message.strip()
        if not message:
            return jsonify({'error': 'Empty message'}), 400

        # Validate parameters before the stream starts: once the generator is
        # running the 200 status line has already been sent, so a bad value
        # could only be reported as an in-stream error event.
        try:
            max_tokens = int(data.get('max_tokens', GENERATION_DEFAULTS['max_new_tokens']))
            temperature = float(data.get('temperature', GENERATION_DEFAULTS['temperature']))
            top_p = float(data.get('top_p', GENERATION_DEFAULTS['top_p']))
        except (TypeError, ValueError):
            return jsonify({'error': 'Invalid generation parameters'}), 400

        def generate_stream():
            """Yield the generated reply as SSE frames, then an end event."""
            try:
                response = generate_response(
                    message,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p
                )

                # Split into words that keep their trailing whitespace, so
                # newlines and spacing survive and the concatenated chunks
                # equal 'fullResponse' (str.split() would collapse them).
                for piece in re.findall(r'\S+\s*|\s+', response):
                    chunk_data = {
                        'chunk': piece,
                        'status': 'streaming'
                    }
                    yield f"data: {json.dumps(chunk_data)}\n\n"
                    time.sleep(0.05)  # Small delay for streaming effect

                # Send completion signal
                end_data = {
                    'complete': True,
                    'fullResponse': response
                }
                yield f"event: end\ndata: {json.dumps(end_data)}\n\n"
            except Exception as e:
                # Headers are already sent here, so report the failure
                # in-band as an SSE error event.
                error_data = {
                    'error': 'Stream error',
                    'details': str(e)
                }
                yield f"event: error\ndata: {json.dumps(error_data)}\n\n"

        return Response(
            generate_stream(),
            content_type='text/event-stream',
            headers={
                'Cache-Control': 'no-cache',
                'Connection': 'keep-alive',
                'Access-Control-Allow-Origin': '*',
                'Access-Control-Allow-Headers': 'Content-Type, Authorization'
            }
        )
    except Exception as e:
        return jsonify({'error': str(e)}), 500
if __name__ == "__main__":
    # Direct execution is for local development. The listening port is taken
    # from the API_PORT environment variable and defaults to 8000.
    port = int(os.getenv("API_PORT", 8000))
    print(f"π Starting API server on port {port}")
    app.run(debug=False, port=port, host="0.0.0.0")