from fastapi.staticfiles import StaticFiles import re import torch import pandas as pd import matplotlib.pyplot as plt import seaborn as sns from transformers import AutoTokenizer, AutoModelForCausalLM from fastapi import FastAPI, File, UploadFile, Form from fastapi.responses import FileResponse import os from fastapi.middleware.cors import CORSMiddleware import logging import matplotlib matplotlib.use("Agg") # Mode sans interface graphique logging.basicConfig(level=logging.INFO) app = FastAPI() app.add_middleware( CORSMiddleware, allow_origins=["*"], # Autorise toutes les origines (à sécuriser en prod) allow_credentials=True, allow_methods=["*"], # Autorise toutes les méthodes (GET, POST, etc.) allow_headers=["*"], # Autorise tous les headers ) # Charger le modèle Hugging Face model_name = "Salesforce/codegen-350M-mono" device = "cuda" if torch.cuda.is_available() else "cpu" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained(model_name).to(device) VALID_PLOTS = {"histplot", "scatterplot", "barplot", "lineplot", "boxplot"} @app.post("/generate_viz/") async def generate_viz(file: UploadFile = File(...), query: str = Form(...)): try: if query not in VALID_PLOTS: return {"error": f"Type de graphique invalide. Choisissez parmi : {', '.join(VALID_PLOTS)}"} df = pd.read_excel(file.file) numeric_cols = df.select_dtypes(include=["number"]).columns if len(numeric_cols) < 2: return {"error": "Le fichier doit contenir au moins deux colonnes numériques."} x_col, y_col = numeric_cols[:2] # Contraintes spécifiques pour éviter l'erreur avec histplot if query == "histplot": prompt_y = "" else: prompt_y = f', y="{y_col}"' # Générer l'invite pour le modèle prompt = f""" ### Génère uniquement du code Python fonctionnel pour tracer un {query} avec Matplotlib et Seaborn ### # Contraintes : # - Utilise 'df' sans recréer de nouvelles données # - Axe X : '{x_col}' # - Enregistre le graphique sous 'plot.png' # - Ne génère que du code Python valide, sans texte explicatif # Contraintes spécifiques pour sns.histplot : # - N'inclut pas "y=" car histplot ne supporte qu'un axe import matplotlib.pyplot as plt import seaborn as sns plt.figure(figsize=(8,6)) sns.{query}(data=df, x="{x_col}"{prompt_y}) plt.savefig("plot.png") plt.close() """ # Génération du code inputs = tokenizer(prompt, return_tensors="pt").to(device) outputs = model.generate(**inputs, max_new_tokens=120, pad_token_id=tokenizer.eos_token_id) generated_code = tokenizer.decode(outputs[0], skip_special_tokens=True).strip() # Nettoyage du code generated_code = re.sub(r"(import matplotlib.pyplot as plt\nimport seaborn as sns\n)+", "import matplotlib.pyplot as plt\nimport seaborn as sns\n", generated_code) if generated_code.strip().endswith("sns."): generated_code = generated_code.rsplit("\n", 1)[0] # Supprime la dernière ligne incomplète print("🔹 Code généré par l'IA :\n", generated_code) # Vérification syntaxique avant exécution try: compile(generated_code, "", "exec") except SyntaxError as e: return {"error": f"Erreur de syntaxe détectée : {e}\nCode généré :\n{generated_code}"} # Vérification des données print(df.head()) # Affiche les premières lignes du dataframe print(df.dtypes) # Vérifie les types de colonnes print(f"Colonne '{x_col}' - Valeurs uniques:", df[x_col].unique()) if df.empty or x_col not in df.columns or df[x_col].isnull().all(): return {"error": f"La colonne '{x_col}' est absente ou ne contient pas de données valides."} # Exécution du code généré exec_env = {"df": df, "plt": plt, "sns": sns, "pd": pd} exec(generated_code, exec_env) # Vérification de l'image générée img_path = "plot.png" if not os.path.exists(img_path): return {"error": "Le fichier plot.png n'a pas été généré."} if os.path.getsize(img_path) == 0: return {"error": "Le fichier plot.png est vide."} plt.close() return FileResponse(img_path, media_type="image/png") except Exception as e: return {"error": f"Erreur lors de la génération du graphique : {str(e)}"} # ✅ Déplace ici le montage des fichiers statiques app.mount("/", StaticFiles(directory="static", html=True), name="static") # Redirection vers index.html @app.get("/") async def root(): return RedirectResponse(url="/index.html")