| from fastapi import FastAPI, File, UploadFile, Form, Request |
| from fastapi.middleware.cors import CORSMiddleware |
| import torch |
| from transformers import DonutProcessor, VisionEncoderDecoderModel, pipeline |
| from PIL import Image, ImageOps, ImageEnhance |
| import easyocr |
| import pytesseract |
| import numpy as np |
| import os, re |
| import requests |
|
|
# ASGI application object; the route decorators in this file register on it.
app = FastAPI()

# CORS middleware configured with wildcards for origins, methods and headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
# Use CUDA when torch reports it as available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Donut processor and encoder-decoder model are loaded from the same checkpoint.
_DONUT_CHECKPOINT = "chinmays18/medical-prescription-ocr"
donut_processor = DonutProcessor.from_pretrained(_DONUT_CHECKPOINT)
donut_model = VisionEncoderDecoderModel.from_pretrained(_DONUT_CHECKPOINT)
donut_model = donut_model.to(device)

# EasyOCR reader created with the 'en' language list.
reader = easyocr.Reader(['en'])


def _ner_pipeline(model_name):
    """Build a token-classification pipeline with "simple" aggregation."""
    return pipeline(
        "token-classification",
        model=model_name,
        aggregation_strategy="simple",
    )


# Two independent NER pipelines; their entities are merged by the
# prescription endpoint.
humadex_pipe = _ner_pipeline("HUMADEX/english_medical_ner")
medner_pipe = _ner_pipeline("blaze999/Medical-NER")

# Text-generation pipeline used by the chat endpoint.
biogpt_pipe = pipeline("text-generation", model="microsoft/BioGPT-Large-PubMedQA")
|
|
def advanced_preprocess(image_path):
    """Load an image and return a high-contrast, binarised RGB copy.

    Steps, in order: open ``image_path`` and convert to mode ``'L'``,
    resize to ``(960, 1280)``, apply autocontrast, double the contrast,
    threshold every pixel at 128 (below -> 0, otherwise -> 255), and
    convert the result back to ``"RGB"``.

    Args:
        image_path: Path (or any source accepted by ``Image.open``) of the
            image to preprocess.

    Returns:
        A PIL image in RGB mode containing only pure black and white pixels.
    """
    grayscale = Image.open(image_path).convert('L').resize((960, 1280))
    boosted = ImageEnhance.Contrast(ImageOps.autocontrast(grayscale)).enhance(2)

    # Hard threshold: every pixel becomes either 0 or 255.
    pixels = np.array(boosted)
    binary = np.where(pixels < 128, 0, 255).astype(np.uint8)

    return Image.fromarray(binary).convert("RGB")
|
|
def clean_text(text):
    """Normalise OCR / generated text into a single clean line.

    Processing order:

    1. Drop anything that looks like a markup tag (``<...>``).
    2. Keep only ASCII letters, digits, whitespace and ``- / ( ) . , : %``.
       This whitelist also removes block/bar glyphs and bullet symbols, so
       no separate pass is needed for them.
    3. Collapse every run of whitespace (spaces, tabs, ``\\r``, ``\\n``,
       ``\\f``, ...) into one space.
    4. Join digits that are separated only by whitespace
       (``"1 2 3"`` -> ``"123"``).
    5. Strip leading and trailing whitespace.

    Args:
        text: The string to clean.

    Returns:
        The cleaned string; ``""`` when nothing survives the filtering.
    """
    # Remove tag-like markup first so its inner text does not leak through.
    text = re.sub(r'<[^>]+>', '', text)

    # Whitelist filter: everything outside the allowed set is dropped.
    text = re.sub(r'[^A-Za-z0-9\s\-/\(\)\.,:%]', '', text)

    # Collapse ALL whitespace, not only literal spaces, so stray '\r' or
    # '\t' characters cannot survive into the output.
    text = re.sub(r'\s+', ' ', text)

    # Lookarounds leave the digits unconsumed, so each digit can serve as
    # the boundary of two adjacent gaps and runs of three or more spaced
    # digits are fully joined.
    text = re.sub(r'(?<=\d)\s+(?=\d)', '', text)

    return text.strip()
|
|
def extract_drugs_and_dose(text):
    """Extract drug keywords, doses and frequencies from text with regexes.

    All three patterns are matched case-insensitively and matches are
    returned with the casing found in ``text``.

    Args:
        text: Text to scan (typically cleaned OCR output).

    Returns:
        A 3-tuple ``(drugs, doses, frequencies)`` of sets of strings:

        * ``drugs`` - the matched form/brand keyword (e.g. ``"TAB"``,
          ``"CALPOL"``). Characters directly following the keyword are
          consumed by the pattern but are not part of the returned value.
        * ``doses`` - amount immediately followed by its unit, without a
          space (e.g. ``"500mg"``, ``"2.5ml"``).
        * ``frequencies`` - matched frequency abbreviations/words
          (e.g. ``"tds"``, ``"sos"``).
    """
    drugs = re.findall(
        r'(SYP|TAB|CAP|SYRUP|INJECTION|DROPS|INHALER|MEFTAL[- ]?P|CALPOL|DELCON|LEVOLIN)[\w\-\/\(\)]*',
        text,
        re.I,
    )

    # Capture the whole number (integer part plus optional fraction). The
    # fraction group is non-capturing so findall yields (amount, unit).
    dose_matches = re.findall(
        r'(\d+(?:\.\d+)?)\s*(ml|mg|g|mcg|tablet|cap|puff|dose|drops)',
        text,
        re.I,
    )

    frequency = re.findall(
        r'(qc[h]?|q6h|tds|t[.]?d[.]?s[.]?|qds|b[.]?d[.]?|bd|sos|daily|once|twice|x\s*\d+d)',
        text,
        re.I,
    )

    doses = {amount + unit for amount, unit in dose_matches}
    return set(drugs), doses, set(frequency)
|
|
def _entity_words(entities, *labels):
    """Return the ``word`` of each entity whose group contains any of *labels*."""
    return {
        ent.get('word', '')
        for ent in entities
        if any(label in ent.get('entity_group', '').upper() for label in labels)
    }


@app.post("/api/prescription")
async def prescription(file: UploadFile = File(...)):
    """Run OCR and entity extraction on an uploaded prescription image.

    The upload is written to a private temporary file, preprocessed, and
    read by three OCR engines (Donut, EasyOCR, Tesseract). The output with
    the most distinct whitespace-separated tokens is cleaned and passed to
    both NER pipelines and to the regex extractor; their results are merged.

    Args:
        file: The uploaded image.

    Returns:
        A dict with keys ``"ocr_text"`` (cleaned text) and ``"drugs"``,
        ``"doses"``, ``"frequencies"`` (lists of unique strings).

    Note:
        All OCR/NER calls below are synchronous and run inside this
        coroutine.
    """
    import tempfile

    # Never build the path from the client-supplied filename: it may be
    # None, contain path separators, or collide with a concurrent upload.
    # Only a short alphanumeric extension is carried over.
    suffix = os.path.splitext(file.filename or "")[1]
    if not re.fullmatch(r'\.[A-Za-z0-9]{1,10}', suffix):
        suffix = ""
    tmp = tempfile.NamedTemporaryFile(prefix="temp_", suffix=suffix, delete=False)
    filepath = tmp.name

    try:
        with tmp:
            tmp.write(await file.read())

        img = advanced_preprocess(filepath)

        # Donut: image -> generated token ids -> decoded text.
        pixel_values = donut_processor(images=img, return_tensors="pt").pixel_values.to(device)
        task_prompt = "<s_ocr>"
        decoder_input_ids = donut_processor.tokenizer(task_prompt, return_tensors="pt").input_ids.to(device)
        generated_ids = donut_model.generate(pixel_values, decoder_input_ids=decoder_input_ids, max_length=512)
        donut_text = donut_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

        # EasyOCR reads the original file; Tesseract reads the preprocessed image.
        easy_text = "\n".join(t[1] for t in reader.readtext(filepath))
        tess_text = pytesseract.image_to_string(img)

        # Keep the candidate with the largest number of distinct tokens.
        texts = [donut_text, easy_text, tess_text]
        best_text = max(texts, key=lambda t: len(set(t.strip().split())))
        cleaned = clean_text(best_text)

        humadex_ents = humadex_pipe(cleaned)
        medner_ents = medner_pipe(cleaned)
        regex_drugs, regex_doses, regex_freqs = extract_drugs_and_dose(cleaned)

        out_drugs = (
            _entity_words(humadex_ents, 'DRUG')
            | _entity_words(medner_ents, 'DRUG')
            | regex_drugs
        )
        out_doses = (
            _entity_words(humadex_ents, 'DOSE', 'DOSAGE')
            | _entity_words(medner_ents, 'DOSE', 'DOSAGE')
            | regex_doses
        )
        out_freqs = (
            _entity_words(humadex_ents, 'FREQUENCY')
            | _entity_words(medner_ents, 'FREQUENCY')
            | regex_freqs
        )
    finally:
        # Always remove the temp file, including when any step above raises.
        os.remove(filepath)

    return {
        "ocr_text": cleaned,
        "drugs": list(out_drugs),
        "doses": list(out_doses),
        "frequencies": list(out_freqs),
    }
|
|
@app.post("/api/chat")
async def chat(message: str = Form(...)):
    """Generate a reply to a form-posted message with the BioGPT pipeline.

    Args:
        message: The ``message`` form field.

    Returns:
        ``{"response": <text>}`` where the text is the first generation's
        ``"generated_text"`` (up to 400 new tokens) after ``clean_text``.
    """
    generations = biogpt_pipe(message, max_new_tokens=400)
    reply = clean_text(generations[0]["generated_text"])
    return {"response": reply}
|
|
# D-ID API key, read once at import time; None when the variable is unset.
D_ID_API_KEY = os.getenv("D_ID_API_KEY")
|
|
@app.post("/api/did_talk")
async def did_talk(request: Request):
    """Create a D-ID talk from text and an image URL and wait for the video.

    Expects a JSON object with ``"text"`` and ``"image_url"``. The text is
    cleaned with ``clean_text`` and truncated to 500 characters before being
    posted to ``https://api.d-id.com/talks``. The talk is then polled up to
    20 times, 3 seconds apart, until a ``result_url`` is present.

    Args:
        request: The incoming request; its body is parsed as JSON.

    Returns:
        ``{"video_url": ...}`` on success, otherwise a dict with an
        ``"error"`` key (and ``"details"`` when the create call fails).
    """
    import asyncio
    import functools

    body = await request.json()
    is_object = isinstance(body, dict)
    text = body.get("text") if is_object else None
    image_url = body.get("image_url") if is_object else None
    if not isinstance(text, str) or not image_url:
        # Report bad input in the same shape as the other failures here
        # instead of letting a KeyError surface as a server error.
        return {"error": "Request body must include 'text' and 'image_url'"}

    text = clean_text(text)[:500]

    headers = {
        "Authorization": f"Basic {D_ID_API_KEY}:",
        "Content-Type": "application/json"
    }
    payload = {
        "script": {
            "type": "text",
            "input": text,
            "provider": {"type": "microsoft", "voice_id": "en-US-GuyNeural"}
        },
        "config": {"result_format": "mp4"},
        "source_url": image_url
    }

    loop = asyncio.get_running_loop()

    def _send(method, url, **kwargs):
        # requests is blocking, so run it in the default executor to keep
        # the event loop free. The timeout bounds a stalled connection.
        call = functools.partial(method, url, headers=headers, timeout=30, **kwargs)
        return loop.run_in_executor(None, call)

    try:
        resp = await _send(requests.post, "https://api.d-id.com/talks", json=payload)
    except requests.RequestException as exc:
        print(f"D-ID create request failed: {exc}")
        return {"error": "Failed to create D-ID talk", "details": str(exc)}

    print(f"D-ID create response: {resp.status_code} - {resp.text}")
    if not resp.ok:
        return {"error": "Failed to create D-ID talk", "details": resp.text}
    talk_id = resp.json()["id"]
    print(f"D-ID talk_id: {talk_id}")

    for attempt in range(1, 21):
        # Non-blocking wait: other requests are served while we poll.
        await asyncio.sleep(3)
        try:
            status = await _send(requests.get, f"https://api.d-id.com/talks/{talk_id}")
            status_data = status.json()
        except (requests.RequestException, ValueError) as exc:
            # A failed poll is not fatal; try again on the next round.
            print(f"Poll #{attempt} failed: {exc}")
            continue
        print(f"Poll #{attempt}: {status_data}")
        if isinstance(status_data, dict) and status_data.get("result_url"):
            print(f"Video ready: {status_data['result_url']}")
            return {"video_url": status_data["result_url"]}

    print("Video generation timed out")
    return {"error": "Timed out or video not ready"}
|
|