File size: 4,767 Bytes
b678e5e
782a505
b678e5e
 
 
 
 
 
 
 
9d1c6dd
907d6a2
6355482
b198707
 
eb1fe2e
 
761d58c
9d1c6dd
 
b678e5e
9d1c6dd
b678e5e
 
9d1c6dd
 
b678e5e
 
 
907d6a2
b678e5e
 
907d6a2
b678e5e
907d6a2
b678e5e
 
9d1c6dd
b678e5e
 
907d6a2
b678e5e
907d6a2
b678e5e
 
 
907d6a2
b678e5e
9d1c6dd
b678e5e
 
 
907d6a2
b678e5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
907d6a2
b678e5e
 
 
 
907d6a2
b678e5e
 
907d6a2
b678e5e
 
 
907d6a2
b678e5e
 
 
 
 
 
907d6a2
b678e5e
 
907d6a2
 
b678e5e
907d6a2
 
 
d75cd79
9d1c6dd
907d6a2
 
 
9d1c6dd
 
907d6a2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130

from fastapi.staticfiles import StaticFiles
import re
import torch
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModelForCausalLM
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import FileResponse
import os
from fastapi.middleware.cors import CORSMiddleware
import logging
import matplotlib
matplotlib.use("Agg")  # Mode sans interface graphique

logging.basicConfig(level=logging.INFO)
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Autorise toutes les origines (à sécuriser en prod)
    allow_credentials=True,
    allow_methods=["*"],  # Autorise toutes les méthodes (GET, POST, etc.)
    allow_headers=["*"],  # Autorise tous les headers
)

# Charger le modèle Hugging Face
model_name = "Salesforce/codegen-350M-mono"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

VALID_PLOTS = {"histplot", "scatterplot", "barplot", "lineplot", "boxplot"}

@app.post("/generate_viz/")
async def generate_viz(file: UploadFile = File(...), query: str = Form(...)):
    try:
        if query not in VALID_PLOTS:
            return {"error": f"Type de graphique invalide. Choisissez parmi : {', '.join(VALID_PLOTS)}"}

        df = pd.read_excel(file.file)

        numeric_cols = df.select_dtypes(include=["number"]).columns
        if len(numeric_cols) < 2:
            return {"error": "Le fichier doit contenir au moins deux colonnes numériques."}

        x_col, y_col = numeric_cols[:2]

        # Contraintes spécifiques pour éviter l'erreur avec histplot
        if query == "histplot":
            prompt_y = ""
        else:
            prompt_y = f', y="{y_col}"'

        # Générer l'invite pour le modèle
        prompt = f"""
### Génère uniquement du code Python fonctionnel pour tracer un {query} avec Matplotlib et Seaborn ###
# Contraintes :
# - Utilise 'df' sans recréer de nouvelles données
# - Axe X : '{x_col}'
# - Enregistre le graphique sous 'plot.png'
# - Ne génère que du code Python valide, sans texte explicatif
# Contraintes spécifiques pour sns.histplot :
# - N'inclut pas "y=" car histplot ne supporte qu'un axe

import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(figsize=(8,6))
sns.{query}(data=df, x="{x_col}"{prompt_y})

plt.savefig("plot.png")
plt.close()
"""

        # Génération du code
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        outputs = model.generate(**inputs, max_new_tokens=120, pad_token_id=tokenizer.eos_token_id)
        generated_code = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

        # Nettoyage du code
        generated_code = re.sub(r"(import matplotlib.pyplot as plt\nimport seaborn as sns\n)+", "import matplotlib.pyplot as plt\nimport seaborn as sns\n", generated_code)
        if generated_code.strip().endswith("sns."):
            generated_code = generated_code.rsplit("\n", 1)[0]  # Supprime la dernière ligne incomplète

        print("🔹 Code généré par l'IA :\n", generated_code)

        # Vérification syntaxique avant exécution
        try:
            compile(generated_code, "<string>", "exec")
        except SyntaxError as e:
            return {"error": f"Erreur de syntaxe détectée : {e}\nCode généré :\n{generated_code}"}

        # Vérification des données
        print(df.head())  # Affiche les premières lignes du dataframe
        print(df.dtypes)  # Vérifie les types de colonnes
        print(f"Colonne '{x_col}' - Valeurs uniques:", df[x_col].unique())

        if df.empty or x_col not in df.columns or df[x_col].isnull().all():
            return {"error": f"La colonne '{x_col}' est absente ou ne contient pas de données valides."}

        # Exécution du code généré
        exec_env = {"df": df, "plt": plt, "sns": sns, "pd": pd}
        exec(generated_code, exec_env)

        # Vérification de l'image générée
        img_path = "plot.png"
        if not os.path.exists(img_path):
            return {"error": "Le fichier plot.png n'a pas été généré."}
        if os.path.getsize(img_path) == 0:
            return {"error": "Le fichier plot.png est vide."}

        plt.close()
        return FileResponse(img_path, media_type="image/png")

    except Exception as e:
        return {"error": f"Erreur lors de la génération du graphique : {str(e)}"}


        # ✅ Déplace ici le montage des fichiers statiques
app.mount("/", StaticFiles(directory="static", html=True), name="static")



# Redirection vers index.html
@app.get("/")
async def root():
    return RedirectResponse(url="/index.html")