RagImage / app.py
goldrode's picture
Update app.py
d153893 verified
import os
import subprocess
from fastapi import FastAPI
from sentence_transformers import SentenceTransformer
from fastapi.middleware.cors import CORSMiddleware
import fitz # PyMuPDF pour extraction du texte PDF
import faiss
import numpy as np
import requests
import pytesseract
from PIL import Image
import gradio as gr
# Installation automatique de Tesseract pour Windows
def install_tesseract_windows():
tesseract_installer_url = "https://github.com/tesseract-ocr/tesseract/releases/download/5.3.0/tesseract-5.3.0.20221214.exe"
installer_path = "tesseract_installer.exe"
try:
if not os.path.exists("C:\\Program Files\\Tesseract-OCR"):
print("Téléchargement de Tesseract OCR en cours...")
# Télécharger l'installateur via requests
response = requests.get(tesseract_installer_url, stream=True)
with open(installer_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
print("Tesseract OCR téléchargé. Installation en cours...")
# Exécuter l'installateur silencieusement
subprocess.run([installer_path, "/S"], shell=True, check=True)
# Configurer le chemin vers Tesseract
os.environ["TESSDATA_PREFIX"] = "C:\\Program Files\\Tesseract-OCR\\"
pytesseract.pytesseract.tesseract_cmd = "C:\\Program Files\\Tesseract-OCR\\tesseract.exe"
print("Tesseract OCR installé avec succès.")
else:
pytesseract.pytesseract.tesseract_cmd = "C:\\Program Files\\Tesseract-OCR\\tesseract.exe"
print("Tesseract OCR déjà installé.")
except Exception as e:
print(f"Erreur lors de l'installation de Tesseract OCR : {e}")
install_tesseract_windows()
# Configuration de l'API Gemini
GEMINI_API_KEY = "AIzaSyALDnGnCP4AuC3uX5dXcYugrfO89clRG9o"
GEMINI_API_URL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent"
# Configuration FAISS et embeddings
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
model = SentenceTransformer(EMBEDDING_MODEL)
INDEX_PATH = "medical_faiss_index"
documents = []
if os.path.exists(INDEX_PATH):
index = faiss.read_index(INDEX_PATH)
print("Index FAISS chargé avec succès.")
else:
index = faiss.IndexFlatL2(model.get_sentence_embedding_dimension())
print("Nouvel index FAISS créé.")
# Fonction d'extraction de texte
def extract_text_from_pdf(file_content: bytes):
pdf = fitz.open(stream=file_content, filetype="pdf")
paragraphs = []
for page in pdf:
text = page.get_text()
if text.strip():
paragraphs.extend([p.strip() for p in text.split("\n\n") if p.strip()])
return paragraphs
def extract_text_from_image(file_path):
return pytesseract.image_to_string(Image.open(file_path))
# Fonction pour ajouter les documents de référence
def add_medical_reference(file):
with open(file.name, "rb") as f:
file_content = f.read()
paragraphs = extract_text_from_pdf(file_content)
embeddings = model.encode(paragraphs)
index.add(np.array(embeddings, dtype="float32"))
documents.extend(paragraphs)
faiss.write_index(index, INDEX_PATH)
return "Référence médicale ajoutée avec succès."
# Fonction pour analyser un fichier (PDF ou Image)
def analyze_blood_test(file):
try:
# Extraire le texte
if file.name.endswith((".png", ".jpg", ".jpeg")):
extracted_text = extract_text_from_image(file.name)
elif file.name.endswith(".pdf"):
with open(file.name, "rb") as f:
file_content = f.read()
paragraphs = extract_text_from_pdf(file_content)
else:
return "Format non supporté. Utilisez un PDF ou une image."
if not paragraphs:
return "Aucun texte valide extrait."
# Recherche segmentée dans FAISS
responses = []
for paragraph in paragraphs:
relevant_docs = search_faiss(paragraph, k=3) # Recherche ciblée pour chaque segment
context = "\n".join(relevant_docs)
enriched_prompt = f"Voici un segment d'analyse :\n{paragraph}\n\nContexte pertinent :\n{context}"
gemini_response = call_gemini_api(enriched_prompt)
responses.append(f"Segment :\n{paragraph}\nRéponse générée :\n{gemini_response}\n")
return "\n\n".join(responses)
except Exception as e:
return f"Erreur : {str(e)}"
# Recherche FAISS
def search_faiss(query, k=5):
query_embedding = model.encode([query])
distances, indices = index.search(np.array(query_embedding, dtype="float32"), k)
return [documents[i] for i in indices[0] if i < len(documents)]
# Appel API Gemini
def call_gemini_api(prompt):
headers = {"Content-Type": "application/json"}
payload = {"contents": [{"parts": [{"text": prompt}]}]}
try:
response = requests.post(f"{GEMINI_API_URL}?key={GEMINI_API_KEY}", json=payload, headers=headers)
return response.json().get("candidates", [{}])[0].get("content", {}).get("parts", [{}])[0].get("text", "Pas de réponse.")
except Exception as e:
return f"Erreur API : {str(e)}"
# Interface Gradio
with gr.Blocks() as demo:
gr.Markdown("## Analyse Médicale avec RAG et Gemini")
with gr.Tab("Ajouter Références Médicales"):
ref_file = gr.File(label="Téléchargez un fichier PDF de référence médicale")
ref_output = gr.Textbox(label="Résultat")
ref_button = gr.Button("Ajouter")
ref_button.click(add_medical_reference, inputs=ref_file, outputs=ref_output)
with gr.Tab("Analyser un Résultat d'Analyse"):
test_file = gr.File(label="Téléchargez un fichier PDF ou image (JPG/PNG)")
analysis_output = gr.Textbox(label="Résultat")
analyze_button = gr.Button("Analyser")
analyze_button.click(analyze_blood_test, inputs=test_file, outputs=analysis_output)
demo.launch()