# medin007's picture
# Update app.py
# 8049b3b verified
import spaces
import gradio as gr
import torch
import traceback
from peft import AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,TextStreamer
# Detect device (GPU or CPU).
device = "cuda" if torch.cuda.is_available() else "cpu"
print("¿GPU disponible?", torch.cuda.is_available())
print("Número de GPUs:", torch.cuda.device_count())

# Model to load (fine-tuned PEFT adapter on Mistral-7B).
model_name = "Projener/AIPlannerModel_MISTRAL-7B"

# Quantization config: load weights in 4-bit, compute in FP16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # 4-bit weight loading (bitsandbytes)
    bnb_4bit_compute_dtype=torch.float16,   # FP16 compute dtype on GPU
)

# Load tokenizer and model. Failures are captured instead of crashing at
# import time so the Gradio UI can still start and report the error.
error_message = None
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # `device_map` already places the model; the original code then called
    # `.to(device)`, which is redundant and raises on 4-bit quantized
    # models in recent transformers ("`.to` is not supported for 4-bit...").
    # Also fall back to "auto" placement when no GPU is present, since
    # {"": 0} requires CUDA device 0 to exist.
    model = AutoPeftModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        device_map={"": 0} if device == "cuda" else "auto",
    )
except Exception as e:
    model = None
    error_message = f"Error al cargar el modelo: {str(e)}"
# Alpaca-style prompt template with placeholders for instruction, input
# and response. NOTE(review): currently unused — generate_response_stream
# builds its prompt prefix inline; kept here for reference.
alpaca_prompt = """
### Instruction:
{}
### Input:
{}
### Response:
{}
"""
# Text-generation callback; @spaces.GPU requests a ZeroGPU slice on HF Spaces.
@spaces.GPU(duration=120)
def generate_response_stream(start_date, end_date, total_duration, total_power, max_new_tokens):
    """Generate a photovoltaic-plant project plan with the loaded model.

    Args:
        start_date: project start date string (e.g. "2022-04-18").
        end_date: project end date string.
        total_duration: total duration in working days.
        total_power: plant power in MW.
        max_new_tokens: generation budget, taken from the UI slider.

    Yields:
        A single string: the text after "### Response:" in the model
        output (capped at 22,400 characters), or an error message.
    """
    if model is None:
        yield "El modelo no pudo cargarse. Verifica la configuración de tu entorno."
        return

    # Fixed task instruction for the fine-tuned model.
    instruction = (
        "Generate a project plan for constructing a photovoltaic plant always including 100 tasks. "
        "Each task has a fixed name and order. Calculate start and end dates for each task, based on the project characteristics provided, "
        "ensuring all task durations are in working days only."
    )
    project_characteristics = (
        f"Project Characteristics:\n"
        f"- start_date: {start_date}\n"
        f"- end_date: {end_date}\n"
        f"- total_duration: {total_duration} working days\n"
        f"- total_power: {total_power} MW"
    )
    # Alpaca-style prefix; the model is expected to complete after "### Response:".
    context_prefix = f"### Instruction:\n{instruction}\n### Input:\n{project_characteristics}\n### Response:\n"

    # Tokenize the prompt and move tensors to the model's device.
    inputs = tokenizer([context_prefix], return_tensors="pt").to(device)
    try:
        # Pass the attention mask explicitly: letting generate() infer it
        # emits a warning and can mis-handle padding.
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_new_tokens,  # value from the UI slider
            use_cache=True,
        )
        # Decode the full sequence (prompt + completion).
        decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only the text after the prompt's "### Response:" marker.
        if "### Response:" in decoded_output:
            response = decoded_output.split("### Response:")[1].strip()
        else:
            response = "No se encontró una respuesta válida en la salida."
        yield response[:22400]  # cap the returned text length
    except Exception:
        # `traceback` is already imported at module level; the original
        # re-imported it here redundantly. Surface the full trace in the
        # UI instead of crashing the Gradio callback.
        error_details = traceback.format_exc()
        yield f"Error durante la generación:\n{error_details}"
# --- Gradio UI ------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## AIPlanner Mistral 7B")

    # Row 1: project date range.
    with gr.Row():
        txt_start = gr.Textbox(label="Fecha de Inicio", value="2022-04-18")
        txt_end = gr.Textbox(label="Fecha de Fin", value="2023-08-18")

    # Row 2: duration and plant size.
    with gr.Row():
        num_duration = gr.Number(label="Duración Total (días laborales)", value=350)
        num_power = gr.Number(label="Potencia Total (MW)", value=400.0)

    # Row 3: advanced generation settings.
    with gr.Row():
        slider_tokens = gr.Slider(
            label="Tokens máximos por generación",
            minimum=64,
            maximum=2048,
            step=64,
            value=128,
            interactive=True,
        )

    # Action button and output area.
    btn_generate = gr.Button("Generar Planificación")
    txt_output = gr.Textbox(label="Respuesta Generada", lines=15)

    # Wire the button to the generation callback.
    btn_generate.click(
        fn=generate_response_stream,
        inputs=[txt_start, txt_end, num_duration, num_power, slider_tokens],
        outputs=txt_output,
    )

demo.launch(share=True)