"""FastAPI service: prescription OCR + medical NER, BioGPT chat, and D-ID talk videos."""

import asyncio
import os
import re
import tempfile

import easyocr
import numpy as np
import pytesseract
import requests
import torch
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, ImageEnhance, ImageOps
from transformers import DonutProcessor, VisionEncoderDecoderModel, pipeline

app = FastAPI()
# NOTE(review): CORS accepts every origin, method and header — confirm this is
# intended outside of local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Models are loaded once at import time and shared by all requests.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
donut_processor = DonutProcessor.from_pretrained("chinmays18/medical-prescription-ocr")
donut_model = VisionEncoderDecoderModel.from_pretrained(
    "chinmays18/medical-prescription-ocr"
).to(device)
reader = easyocr.Reader(['en'])
humadex_pipe = pipeline(
    "token-classification",
    model="HUMADEX/english_medical_ner",
    aggregation_strategy="simple",
)
medner_pipe = pipeline(
    "token-classification",
    model="blaze999/Medical-NER",
    aggregation_strategy="simple",
)
biogpt_pipe = pipeline("text-generation", model="microsoft/BioGPT-Large-PubMedQA")

D_ID_API_KEY = os.environ.get("D_ID_API_KEY")
DID_TALKS_URL = "https://api.d-id.com/talks"
DID_REQUEST_TIMEOUT_SECONDS = 30
DID_POLL_ATTEMPTS = 20
DID_POLL_INTERVAL_SECONDS = 3
DID_MAX_SCRIPT_CHARS = 500  # Limit to 500 chars for TTS stability


def advanced_preprocess(image_path):
    """Load an image and return a resized, binarized RGB copy for OCR.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A 960x1280 ``PIL.Image`` in RGB mode whose pixels are pure black or
        pure white.
    """
    # Open inside a context manager so the file handle is released right away;
    # convert() loads the pixel data and returns an independent image.
    with Image.open(image_path) as source:
        img = source.convert('L')
    img = img.resize((960, 1280))
    img = ImageOps.autocontrast(img)
    img = ImageEnhance.Contrast(img).enhance(2)
    pixels = np.array(img)
    # Hard threshold at mid-grey: everything becomes 0 or 255.
    pixels = np.where(pixels < 128, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels).convert("RGB")


def clean_text(text):
    """Strip markup, decorative glyphs and stray characters from model output.

    Args:
        text: Raw text; ``None`` or an empty string yields ``""``.

    Returns:
        The cleaned, stripped string. Only letters, digits, whitespace and
        the punctuation ``- / ( ) . , : %`` survive, and whitespace that
        separates digits is removed.
    """
    if not text:
        return ""
    # Remove <...> tags, e.g. < /FREETEXT >, < / ABSTRACT >
    text = re.sub(r'<[^>]+>', '', text)
    # Remove common bar/block Unicode, including standalone ones
    text = re.sub(r'[▃▅▇█━—]+', '', text)
    # Remove bullet-style glyphs
    text = re.sub(r'[•◆■★●]', '', text)
    # Turn newlines and form feeds into plain spaces
    text = text.replace('\n', ' ').replace('\f', ' ')
    # Only keep useful punctuation and collapse spaces
    text = re.sub(r'[^A-Za-z0-9\s\-/\(\)\.,:%]', '', text)
    text = re.sub(' +', ' ', text)
    # Remove whitespace between digits. Lookarounds leave the digits
    # unconsumed, so runs such as "1 2 3" collapse fully to "123" (a
    # consuming pattern would stop at "12 3").
    text = re.sub(r'(?<=\d)\s+(?=\d)', '', text)
    return text.strip()


def extract_drugs_and_dose(text):
    """Pull drug keywords, doses and dosing frequencies out of text by regex.

    Args:
        text: Cleaned OCR text.

    Returns:
        A ``(drugs, doses, frequencies)`` tuple of sets of strings.
        ``drugs`` holds only the matched form/brand keyword (the trailing
        characters in the pattern are outside the capture group). ``doses``
        holds ``"<number><unit>"`` strings such as ``"500mg"``.
    """
    drugs = re.findall(
        r'(SYP|TAB|CAP|SYRUP|INJECTION|DROPS|INHALER|MEFTAL[- ]?P|CALPOL|DELCON|LEVOLIN)[\w\-\/\(\)]*',
        text,
        re.I,
    )
    # Capture the whole number (integer part included) in group 1; capturing
    # only the fractional part would lose the digits before the dot.
    dose_pairs = re.findall(
        r'(\d+(?:\.\d+)?)\s*(ml|mg|g|mcg|tablet|cap|puff|dose|drops)',
        text,
        re.I,
    )
    frequency = re.findall(
        r'(qc[h]?|q6h|tds|t[.]?d[.]?s[.]?|qds|b[.]?d[.]?|bd|sos|daily|once|twice|x\s*\d+d)',
        text,
        re.I,
    )
    doses = {amount + unit for amount, unit in dose_pairs}
    return set(drugs), doses, set(frequency)


def _entity_words(entities, *labels):
    """Return the non-empty ``word`` of each entity whose group contains a label."""
    return {
        ent.get('word', '')
        for ent in entities
        if ent.get('word')
        and any(label in ent.get('entity_group', '').upper() for label in labels)
    }


def _safe_suffix(filename):
    """Return a short alphanumeric file extension from *filename*, or ``""``."""
    suffix = os.path.splitext(filename or "")[1]
    return suffix if re.fullmatch(r'\.[A-Za-z0-9]{1,10}', suffix) else ""


@app.post("/api/prescription")
async def prescription(file: UploadFile = File(...)):
    """OCR an uploaded prescription image and extract drugs, doses, frequencies.

    The image is run through Donut, EasyOCR and Tesseract; the transcript with
    the most distinct whitespace-separated tokens is cleaned and then mined by
    two NER pipelines plus regexes.

    Args:
        file: The uploaded image.

    Returns:
        A dict with keys ``ocr_text``, ``drugs``, ``doses`` and
        ``frequencies`` (the last three are sorted lists of strings).
    """
    contents = await file.read()
    # Never build the path from the client-supplied filename: use a unique
    # server-chosen temp file and keep only a sanitized extension.
    with tempfile.NamedTemporaryFile(
        prefix="temp_", suffix=_safe_suffix(file.filename), delete=False
    ) as tmp:
        tmp.write(contents)
        filepath = tmp.name

    try:
        img = advanced_preprocess(filepath)

        pixel_values = donut_processor(images=img, return_tensors="pt").pixel_values.to(device)
        task_prompt = "<s_ocr>"
        decoder_input_ids = donut_processor.tokenizer(
            task_prompt, return_tensors="pt"
        ).input_ids.to(device)
        generated_ids = donut_model.generate(
            pixel_values, decoder_input_ids=decoder_input_ids, max_length=512
        )
        donut_text = donut_processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0].strip()

        # EasyOCR reads the original upload; Tesseract reads the binarized image.
        easy_text = "\n".join(hit[1] for hit in reader.readtext(filepath))
        tess_text = pytesseract.image_to_string(img)
    finally:
        # Always remove the temp file, even when an OCR step raises.
        os.remove(filepath)

    candidates = [donut_text, easy_text, tess_text]
    best_text = max(candidates, key=lambda t: len(set(t.strip().split())))
    cleaned = clean_text(best_text)

    humadex_ents = humadex_pipe(cleaned)
    medner_ents = medner_pipe(cleaned)
    regex_drugs, regex_doses, regex_freqs = extract_drugs_and_dose(cleaned)

    out_drugs = (
        _entity_words(humadex_ents, 'DRUG')
        | _entity_words(medner_ents, 'DRUG')
        | regex_drugs
    )
    out_doses = (
        _entity_words(humadex_ents, 'DOSE', 'DOSAGE')
        | _entity_words(medner_ents, 'DOSE', 'DOSAGE')
        | regex_doses
    )
    out_freqs = (
        _entity_words(humadex_ents, 'FREQUENCY')
        | _entity_words(medner_ents, 'FREQUENCY')
        | regex_freqs
    )

    # Sorted so the response is deterministic (set order is not).
    return {
        "ocr_text": cleaned,
        "drugs": sorted(out_drugs),
        "doses": sorted(out_doses),
        "frequencies": sorted(out_freqs),
    }


@app.post("/api/chat")
async def chat(message: str = Form(...)):
    """Generate a BioGPT continuation for *message* and return it cleaned.

    Args:
        message: Prompt text from the form field ``message``.

    Returns:
        ``{"response": <cleaned generated text>}``.
    """
    # Query BioGPT
    result = biogpt_pipe(message, max_new_tokens=400)[0]["generated_text"]
    # Clean the BioGPT output
    result = clean_text(result)
    return {"response": result}


@app.post("/api/did_talk")
async def did_talk(request: Request):
    """Create a D-ID talk from text + image URL and poll until the video exists.

    Expects a JSON body with ``text`` and ``image_url``.

    Returns:
        ``{"video_url": ...}`` on success, otherwise a dict with an ``error``
        key (and ``details`` when the create call fails).
    """
    if not D_ID_API_KEY:
        return {"error": "D_ID_API_KEY is not configured"}

    body = await request.json()
    if not isinstance(body, dict) or not body.get("text") or not body.get("image_url"):
        return {"error": "Request body must include 'text' and 'image_url'"}
    image_url = body["image_url"]

    # Clean text for D-ID TTS and cap its length
    text = clean_text(str(body["text"]))[:DID_MAX_SCRIPT_CHARS]

    headers = {
        "Authorization": f"Basic {D_ID_API_KEY}:",
        "Content-Type": "application/json"
    }
    payload = {
        "script": {
            "type": "text",
            "input": text,
            "provider": {"type": "microsoft", "voice_id": "en-US-GuyNeural"}
        },
        "config": {"result_format": "mp4"},
        "source_url": image_url
    }

    # requests is blocking: run it in the threadpool so the event loop stays
    # free, and always pass a timeout so a stalled connection cannot hang.
    try:
        resp = await run_in_threadpool(
            requests.post,
            DID_TALKS_URL,
            headers=headers,
            json=payload,
            timeout=DID_REQUEST_TIMEOUT_SECONDS,
        )
    except requests.RequestException as exc:
        return {"error": "Failed to create D-ID talk", "details": str(exc)}

    print(f"D-ID create response: {resp.status_code} - {resp.text}")
    if not resp.ok:
        return {"error": "Failed to create D-ID talk", "details": resp.text}

    talk_id = resp.json()["id"]
    print(f"D-ID talk_id: {talk_id}")

    for i in range(DID_POLL_ATTEMPTS):
        # Non-blocking wait; time.sleep() here would freeze every other request.
        await asyncio.sleep(DID_POLL_INTERVAL_SECONDS)
        try:
            status = await run_in_threadpool(
                requests.get,
                f"{DID_TALKS_URL}/{talk_id}",
                headers=headers,
                timeout=DID_REQUEST_TIMEOUT_SECONDS,
            )
            status_data = status.json()
        except (requests.RequestException, ValueError) as exc:
            # A single failed poll should not abort the whole wait.
            print(f"Poll #{i+1} failed: {exc}")
            continue
        print(f"Poll #{i+1}: {status_data}")
        result_url = status_data.get("result_url") if isinstance(status_data, dict) else None
        if result_url:
            print(f"Video ready: {result_url}")
            return {"video_url": result_url}

    print("Video generation timed out")
    return {"error": "Timed out or video not ready"}