# asm / main.py
# fatimazahra22-4's picture
# Create main.py
# 109fccb verified
import base64
import io
import os
import re
import tempfile
import textwrap

import fitz  # PyMuPDF
import matplotlib.pyplot as plt
import pandas as pd
import pytesseract
import seaborn as sns
from fastapi import FastAPI, Request, UploadFile, Form, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from PIL import Image
from transformers import pipeline
# Initialize the FastAPI application
app = FastAPI()
# Jinja2 templates for the frontend, loaded from the "templates" directory.
# NOTE(review): StaticFiles is imported, but no static directory is mounted in the code visible here.
templates = Jinja2Templates(directory="templates")
# Module-level Hugging Face pipelines, created once when this module is imported
# and shared by the summary, question-answering and image-captioning code below.
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
image_captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
# Validate an uploaded file and read its payload
async def process_uploaded_file(file: UploadFile):
    """Validate the upload's extension and return it together with the raw bytes.

    The extension is checked before the payload is read, so an unsupported
    upload is rejected without loading its content into memory.

    Args:
        file: The upload received by the endpoint.

    Returns:
        A ``(file_ext, content)`` tuple: the lower-cased extension without the
        leading dot, and the complete file content as bytes.

    Raises:
        HTTPException: 400 when the filename is missing, has no extension, or
            the extension is not one of the supported types.
    """
    supported_extensions = {"pdf", "docx", "xlsx", "pptx", "jpg", "jpeg", "png"}

    # UploadFile.filename may be None; treat that like a name without extension.
    filename = file.filename or ""
    # rpartition tells a real extension ("report.pdf") apart from a bare name
    # ("pdf"), which a plain split('.')[-1] would mistake for an extension.
    _, dot, suffix = filename.rpartition(".")
    file_ext = suffix.lower() if dot else ""

    if file_ext not in supported_extensions:
        raise HTTPException(status_code=400, detail="Unsupported file type.")

    content = await file.read()
    return file_ext, content
# Extract text from a file's raw bytes
def extract_text(content: bytes, file_ext: str) -> str:
    """Extract text from *content* according to the file extension.

    Args:
        content: Raw bytes of the file.
        file_ext: Lower-case extension without the dot (e.g. ``"pdf"``).

    Returns:
        The extracted text. For images, the OCR result is returned when it is
        non-blank; otherwise the caption generated by ``image_captioner`` is
        returned. An unrecognised extension yields an empty string.
    """
    if file_ext == "pdf":
        # The context manager closes the PyMuPDF document once the text is read,
        # instead of leaving the handle open until garbage collection.
        with fitz.open(stream=content, filetype="pdf") as pdf:
            return "\n".join(page.get_text("text") for page in pdf)

    if file_ext == "docx":
        from docx import Document

        document = Document(io.BytesIO(content))
        return "\n".join(para.text for para in document.paragraphs if para.text.strip())

    if file_ext == "xlsx":
        df = pd.read_excel(io.BytesIO(content))
        # One line per cell, in row-major order.
        return "\n".join(df.astype(str).values.flatten())

    if file_ext == "pptx":
        from pptx import Presentation

        presentation = Presentation(io.BytesIO(content))
        return "\n".join(
            shape.text
            for slide in presentation.slides
            for shape in slide.shapes
            if hasattr(shape, "text") and shape.text.strip()
        )

    if file_ext in {"jpg", "jpeg", "png"}:
        # Keep the image open for both OCR and captioning, then release it.
        with Image.open(io.BytesIO(content)) as image:
            ocr_text = pytesseract.image_to_string(image)
            if ocr_text.strip():
                return ocr_text
            # No readable text in the picture: fall back to a generated caption.
            caption = image_captioner(image)
        return caption[0]["generated_text"]

    return ""
# Summarize a text
def generate_summary(text: str, chunk_size: int = 1000) -> str:
    """Summarize *text* chunk by chunk and join the partial summaries.

    Whitespace is collapsed first, then the text is split into chunks of at
    most ``chunk_size`` characters. Chunks end on word boundaries so the model
    never receives a word cut in half; only a single word longer than
    ``chunk_size`` is broken.

    Args:
        text: The text to summarize.
        chunk_size: Maximum number of characters per chunk sent to the
            summarizer. Must be positive.

    Returns:
        The partial summaries joined with single spaces, or an empty string
        when *text* is empty or contains only whitespace.

    Raises:
        ValueError: If ``chunk_size`` is not positive (raised by ``textwrap``).
    """
    normalized = re.sub(r'\s+', ' ', text).strip()
    if not normalized:
        return ""

    chunks = textwrap.wrap(
        normalized,
        width=chunk_size,
        break_long_words=True,
        break_on_hyphens=False,
    )
    partial_summaries = [
        summarizer(chunk, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
        for chunk in chunks
    ]
    return " ".join(partial_summaries)
# Answer a question from the extracted text
def generate_qa_answer(text: str, question: str) -> str:
    """Return the answer that ``qa_model`` finds for *question* in *text*.

    Args:
        text: Context passed to the question-answering pipeline.
        question: The question to answer.

    Returns:
        The ``"answer"`` field of the pipeline's result.
    """
    prediction = qa_model(context=text, question=question)
    return prediction["answer"]
# Render a chart from tabular data
def generate_visualization(df: pd.DataFrame, chart_type: str, x_col: str, y_col: str) -> str:
    """Plot ``y_col`` against ``x_col`` and save the chart as a PNG file.

    Args:
        df: Data to plot.
        chart_type: One of ``"line"``, ``"bar"`` or ``"scatter"``.
        x_col: Name of the column used for the x axis.
        y_col: Name of the column used for the y axis.

    Returns:
        Path of the PNG file that was written. The file is created with
        ``delete=False``, so the caller is responsible for removing it.

    Raises:
        ValueError: If ``chart_type`` is not supported, or if ``x_col`` or
            ``y_col`` is not a column of ``df``.
    """
    supported_chart_types = ("line", "bar", "scatter")
    # Reject bad input up front instead of silently saving an empty figure.
    if chart_type not in supported_chart_types:
        raise ValueError(
            f"Unsupported chart type {chart_type!r}; expected one of {supported_chart_types}."
        )
    missing_columns = [col for col in (x_col, y_col) if col not in df.columns]
    if missing_columns:
        raise ValueError(f"Column(s) not found in data: {missing_columns}")

    figure = plt.figure(figsize=(10, 6))
    try:
        if chart_type == "line":
            sns.lineplot(data=df, x=x_col, y=y_col)
        elif chart_type == "bar":
            sns.barplot(data=df, x=x_col, y=y_col)
        else:
            sns.scatterplot(data=df, x=x_col, y=y_col)
        plt.tight_layout()
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
            plt.savefig(tmpfile.name)
        return tmpfile.name
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(figure)
@app.get("/", response_class=HTMLResponse)
async def form_page(request: Request):
    """Render the home page containing the upload form."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
def _read_image_as_base64(image_path: str) -> str:
    """Return the file at *image_path* base64-encoded, then delete the file."""
    try:
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')
    finally:
        # generate_visualization creates the file with delete=False, so it must
        # be removed here; a failed cleanup must not fail the request.
        try:
            os.remove(image_path)
        except OSError:
            pass


@app.post("/process", response_class=HTMLResponse)
async def process_file(
    request: Request,
    uploaded_file: UploadFile,
    action: str = Form(...),
    question: str = Form(None),
    x_column: str = Form(None),
    y_column: str = Form(None),
):
    """Process the uploaded file according to the requested action.

    Args:
        request: Incoming request, passed to the template.
        uploaded_file: The file to process.
        action: ``"resume"`` (summary), ``"qa"`` (question answering) or
            ``"visualisation"`` (line chart from an xlsx file).
        question: Question to answer; required when ``action`` is ``"qa"``.
        x_column: Column for the x axis; required for ``"visualisation"``.
        y_column: Column for the y axis; required for ``"visualisation"``.

    Returns:
        The rendered ``result.html`` template with the result string.

    Raises:
        HTTPException: 400 for invalid input (unsupported file, unknown action,
            missing question or columns, no extractable text for QA, non-xlsx
            file for visualisation); 500 for any unexpected processing error.
    """
    try:
        file_ext, content = await process_uploaded_file(uploaded_file)

        # NOTE(review): `result` mixes raw HTML (the <img> tag) with text derived
        # from the uploaded document. If result.html renders it unescaped, the
        # document text is an HTML-injection vector — confirm against the template.
        if action == "resume":
            summary = generate_summary(extract_text(content, file_ext))
            result = f"Résumé généré : {summary}"
        elif action == "qa":
            if not question:
                raise HTTPException(status_code=400, detail="Question required for QA.")
            text = extract_text(content, file_ext)
            if not text.strip():
                # Report the empty context as a client error rather than letting
                # the question-answering step fail and surface as a 500.
                raise HTTPException(status_code=400, detail="No text could be extracted from the file.")
            answer = generate_qa_answer(text, question)
            result = f"Réponse à la question : {answer}"
        elif action == "visualisation":
            if not x_column or not y_column:
                raise HTTPException(status_code=400, detail="x_column and y_column required for visualization.")
            if file_ext != "xlsx":
                raise HTTPException(status_code=400, detail="Visualization requires an xlsx file.")
            df = pd.read_excel(io.BytesIO(content))
            missing_columns = [col for col in (x_column, y_column) if col not in df.columns]
            if missing_columns:
                raise HTTPException(status_code=400, detail=f"Column(s) not found in file: {missing_columns}")
            image_path = generate_visualization(df, "line", x_column, y_column)
            img_base64 = _read_image_as_base64(image_path)
            result = f"<img src='data:image/png;base64,{img_base64}' alt='Visualisation'>"
        else:
            raise HTTPException(status_code=400, detail="Action inconnue")

        return templates.TemplateResponse("result.html", {
            "request": request,
            "result": result
        })
    except HTTPException:
        # Deliberate client errors keep their own status code.
        raise
    except Exception as e:
        # An HTTPException only produces an error response when it is raised.
        raise HTTPException(status_code=500, detail=str(e)) from e