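"""FastAPI service exposing fine-tuned medical language models.

Serves LED, Gemma 3, and MedGemma checkpoints with optional LoRA adapters
(loaded via PEFT) behind a single /generate endpoint.
"""
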
import os
import gc
import logging
import warnings

import torch
from pydantic import BaseModel
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM,
    BitsAndBytesConfig
)
from huggingface_hub import login
from peft import PeftModel
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware

warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class GenerationRequest(BaseModel):
    input_text: str
    model_name: str


class GenerationResponse(BaseModel):
    response: str


class MedicalKnowledgeTester:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        # Authenticate with Hugging Face when a token is provided
        # (required for gated checkpoints such as the Gemma models).
        hf_token = os.environ.get("HF_TOKEN")
        if hf_token:
            login(token=hf_token)
            logger.info("Logged in to Hugging Face using token from environment variable.")

        # Caches of loaded models and tokenizers, keyed by model name.
        self.models = {}
        self.tokenizers = {}

        # Registry of supported checkpoints and their fine-tuned LoRA adapters.
        self.model_configs = {
            "led-base": {
                "model_type": "encoder-decoder",
                "base_model": "allenai/led-base-16384",
                "adapter_model": "ALQAMARI/led-base-sbar-summary-adapter",
                "max_length": 4096,
                "use_quantization": False,
            },
            "gemma-3-12b-it": {
                "model_type": "decoder",
                "base_model": "google/gemma-3-12b-it",
                "adapter_model": "ALQAMARI/gemma-3-12b-it-summary-adapter",
                "max_length": 4096,
                "use_quantization": True,
            },
            "medgemma-27b": {
                "model_type": "decoder",
                "base_model": "google/medgemma-27b-text-it",
                "adapter_model": "ALQAMARI/medgemma-sbar-summary",
                "max_length": 4096,
                "use_quantization": True,
            },
        }

        self.GENERAL_TEMPLATE = """You are a versatile and highly skilled medical AI assistant. Your role is to provide accurate and helpful responses to medical inquiries.
- If the user provides a patient record, a long medical report, or text that requires summarization, your primary task is to summarize it concisely. Highlight the key findings, diagnoses, and recommendations in a clear format suitable for other medical professionals.
- If the user asks a direct question, provide a comprehensive and clear medical explanation.
- Analyze the user's input below and respond in the most appropriate manner, either as a summarizer or a knowledge expert.

User Input:
{input_text}

Your Response:"""

    def load_model(self, model_name: str):
        if model_name in self.models:
            logger.info(f"Model '{model_name}' is already loaded.")
            return

        if model_name not in self.model_configs:
            raise ValueError(f"Model {model_name} not supported.")

        config = self.model_configs[model_name]
        logger.info(f"Loading {model_name}...")

        model_kwargs = {"device_map": "auto", "trust_remote_code": True}

        if config["use_quantization"]:
            # 4-bit NF4 quantization keeps the larger checkpoints within GPU memory.
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True, bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True
            )
            model_kwargs["quantization_config"] = bnb_config
            model_kwargs["torch_dtype"] = torch.bfloat16
        else:
            model_kwargs["torch_dtype"] = torch.float16

        tokenizer = AutoTokenizer.from_pretrained(config["base_model"])
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        # Left padding is required for correct batched generation with decoder-only models.
        tokenizer.padding_side = "left"

        if config["model_type"] == "encoder-decoder":
            base_model = AutoModelForSeq2SeqLM.from_pretrained(config["base_model"], **model_kwargs)
        else:
            base_model = AutoModelForCausalLM.from_pretrained(config["base_model"], **model_kwargs)

        # Attach the fine-tuned LoRA adapter; fall back to the base model if it fails.
        try:
            model = PeftModel.from_pretrained(base_model, config["adapter_model"])
            logger.info(f"Successfully loaded adapter from {config['adapter_model']}")
        except Exception as e:
            logger.error(f"Failed to load adapter: {e}. Using base model without adapter.")
            model = base_model

        model.eval()

        self.models[model_name] = model
        self.tokenizers[model_name] = tokenizer
        logger.info(f"{model_name} loaded successfully.")
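
    # Not part of the original file: a minimal cleanup sketch for freeing GPU
    # memory when swapping models (the otherwise-unused `gc` import suggests
    # this was intended). Assumes models are only referenced by these caches.
    def unload_model(self, model_name: str):
        if model_name not in self.models:
            return
        # Drop the cached references, then force collection and release VRAM.
        del self.models[model_name]
        del self.tokenizers[model_name]
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info(f"{model_name} unloaded.")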

    def generate_response(self, model_name: str, input_text: str) -> str:
        # Lazily load the model on first use.
        if model_name not in self.models:
            self.load_model(model_name)

        model = self.models[model_name]
        tokenizer = self.tokenizers[model_name]
        config = self.model_configs[model_name]

        prompt = self.GENERAL_TEMPLATE.format(input_text=input_text)

        # Decoder-only models receive the full instruction prompt; the
        # encoder-decoder summarizer is fed the raw input text directly.
        if config["model_type"] == "decoder":
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=config["max_length"]).to(self.device)
        else:
            inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=config["max_length"]).to(self.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs, max_new_tokens=512, do_sample=True, temperature=0.1,
                pad_token_id=tokenizer.eos_token_id, repetition_penalty=1.1
            )

        # Decoder-only outputs include the prompt, so strip the input tokens;
        # encoder-decoder outputs contain only the generated sequence.
        if config["model_type"] == "decoder":
            input_length = inputs.input_ids.shape[1]
            generated_tokens = outputs[0][input_length:]
        else:
            generated_tokens = outputs[0]

        response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        return response.strip()


app = FastAPI()
tester = MedicalKnowledgeTester()

# Permissive CORS so browser clients on any origin can call the API;
# tighten allow_origins for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
async def startup_event():
    logger.info("Server starting up. Pre-loading default model...")
    try:
        # Warm the default model so the first request doesn't pay the load cost.
        tester.load_model("medgemma-27b")
    except Exception as e:
        logger.error(f"Could not pre-load medgemma-27b model: {e}")
| @app.get("/") |
| def read_root(): |
| return {"status": "Medical AI API - I AM THE NEW VERSION"} |
|
|


@app.post("/generate", response_model=GenerationResponse)
async def generate(request: GenerationRequest):
    logger.info(f"Received request for model: {request.model_name}")
    try:
        response_text = tester.generate_response(
            model_name=request.model_name,
            input_text=request.input_text
        )
        return GenerationResponse(response=response_text)
    except ValueError as e:
        # Unsupported model names are client errors, not server faults.
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Error during generation: {e}")
        raise HTTPException(status_code=500, detail=str(e))
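

# A minimal local entry point, not part of the original file; it assumes
# uvicorn is installed and this module is saved as, e.g., main.py.
# Example request once the server is running:
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"input_text": "Patient presents with ...", "model_name": "led-base"}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)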