Spaces:
Sleeping
Sleeping
| from fastapi.staticfiles import StaticFiles | |
| import re | |
| import torch | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from fastapi import FastAPI, File, UploadFile, Form | |
| from fastapi.responses import FileResponse | |
| import os | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import logging | |
| import matplotlib | |
| matplotlib.use("Agg") # Mode sans interface graphique | |
| logging.basicConfig(level=logging.INFO) | |
| app = FastAPI() | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], # Autorise toutes les origines (à sécuriser en prod) | |
| allow_credentials=True, | |
| allow_methods=["*"], # Autorise toutes les méthodes (GET, POST, etc.) | |
| allow_headers=["*"], # Autorise tous les headers | |
| ) | |
| # Charger le modèle Hugging Face | |
| model_name = "Salesforce/codegen-350M-mono" | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| model = AutoModelForCausalLM.from_pretrained(model_name).to(device) | |
| VALID_PLOTS = {"histplot", "scatterplot", "barplot", "lineplot", "boxplot"} | |
| async def generate_viz(file: UploadFile = File(...), query: str = Form(...)): | |
| try: | |
| if query not in VALID_PLOTS: | |
| return {"error": f"Type de graphique invalide. Choisissez parmi : {', '.join(VALID_PLOTS)}"} | |
| df = pd.read_excel(file.file) | |
| numeric_cols = df.select_dtypes(include=["number"]).columns | |
| if len(numeric_cols) < 2: | |
| return {"error": "Le fichier doit contenir au moins deux colonnes numériques."} | |
| x_col, y_col = numeric_cols[:2] | |
| # Contraintes spécifiques pour éviter l'erreur avec histplot | |
| if query == "histplot": | |
| prompt_y = "" | |
| else: | |
| prompt_y = f', y="{y_col}"' | |
| # Générer l'invite pour le modèle | |
| prompt = f""" | |
| ### Génère uniquement du code Python fonctionnel pour tracer un {query} avec Matplotlib et Seaborn ### | |
| # Contraintes : | |
| # - Utilise 'df' sans recréer de nouvelles données | |
| # - Axe X : '{x_col}' | |
| # - Enregistre le graphique sous 'plot.png' | |
| # - Ne génère que du code Python valide, sans texte explicatif | |
| # Contraintes spécifiques pour sns.histplot : | |
| # - N'inclut pas "y=" car histplot ne supporte qu'un axe | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| plt.figure(figsize=(8,6)) | |
| sns.{query}(data=df, x="{x_col}"{prompt_y}) | |
| plt.savefig("plot.png") | |
| plt.close() | |
| """ | |
| # Génération du code | |
| inputs = tokenizer(prompt, return_tensors="pt").to(device) | |
| outputs = model.generate(**inputs, max_new_tokens=120, pad_token_id=tokenizer.eos_token_id) | |
| generated_code = tokenizer.decode(outputs[0], skip_special_tokens=True).strip() | |
| # Nettoyage du code | |
| generated_code = re.sub(r"(import matplotlib.pyplot as plt\nimport seaborn as sns\n)+", "import matplotlib.pyplot as plt\nimport seaborn as sns\n", generated_code) | |
| if generated_code.strip().endswith("sns."): | |
| generated_code = generated_code.rsplit("\n", 1)[0] # Supprime la dernière ligne incomplète | |
| print("🔹 Code généré par l'IA :\n", generated_code) | |
| # Vérification syntaxique avant exécution | |
| try: | |
| compile(generated_code, "<string>", "exec") | |
| except SyntaxError as e: | |
| return {"error": f"Erreur de syntaxe détectée : {e}\nCode généré :\n{generated_code}"} | |
| # Vérification des données | |
| print(df.head()) # Affiche les premières lignes du dataframe | |
| print(df.dtypes) # Vérifie les types de colonnes | |
| print(f"Colonne '{x_col}' - Valeurs uniques:", df[x_col].unique()) | |
| if df.empty or x_col not in df.columns or df[x_col].isnull().all(): | |
| return {"error": f"La colonne '{x_col}' est absente ou ne contient pas de données valides."} | |
| # Exécution du code généré | |
| exec_env = {"df": df, "plt": plt, "sns": sns, "pd": pd} | |
| exec(generated_code, exec_env) | |
| # Vérification de l'image générée | |
| img_path = "plot.png" | |
| if not os.path.exists(img_path): | |
| return {"error": "Le fichier plot.png n'a pas été généré."} | |
| if os.path.getsize(img_path) == 0: | |
| return {"error": "Le fichier plot.png est vide."} | |
| plt.close() | |
| return FileResponse(img_path, media_type="image/png") | |
| except Exception as e: | |
| return {"error": f"Erreur lors de la génération du graphique : {str(e)}"} | |
| # ✅ Déplace ici le montage des fichiers statiques | |
| app.mount("/", StaticFiles(directory="static", html=True), name="static") | |
| # Redirection vers index.html | |
| async def root(): | |
| return RedirectResponse(url="/index.html") | |