DAC / app.py
Mattimax's picture
Upload 3 files
b9150bd verified
import logging
import torch
import json
import os
from flask import Flask, render_template, request, Response, stream_with_context
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig
from threading import Thread
import random
import numpy as np
logging.basicConfig(level=logging.INFO)
app = Flask(__name__)
MODEL_NAME = "Mattimax/DATA-AI_Chat_3_0.5B"
# Controlla se c'è una GPU, altrimenti usa CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Imposta seed per reproducibilità
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if device == "cuda":
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
# Configurazione per caricare il modello in modo efficiente
bnb_config = None
if device == "cuda":
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16
)
logging.info("Caricamento tokenizer e modello: %s (device=%s)", MODEL_NAME, device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Leggi il chat template dal tokenizer config
chat_template = None
try:
from transformers.models.auto.tokenization_auto import get_tokenizer_config
config_dict = get_tokenizer_config(MODEL_NAME)
chat_template = config_dict.get("chat_template")
logging.info("Chat template caricato: %s", chat_template[:100] if chat_template else "Non disponibile")
except Exception as e:
logging.warning("Impossibile caricare chat_template: %s", e)
# Fallback: template di default semplice
chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}User: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}Assistant: {{ message['content'] }}\n{% else %}{{ message['role'] }}: {{ message['content'] }}\n{% endif %}{% endfor %}"
# assicurati che esista un pad_token
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
if device == "cuda":
# usa device_map auto per posizionare i pesi sulla GPU in modo efficiente
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
quantization_config=bnb_config,
device_map="auto"
)
else:
# caricamento su CPU (può essere lento) - evita .to() per device_map compatibile
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model.to("cpu")
model.eval() # modo valutazione per stabilità
# System prompt per guidare il comportamento
SYSTEM_PROMPT = """Tu sei DAC, un assistente intelligente e amichevole. Rispondi in modo coerente, chiaro e utile.
Se non conosci la risposta, ammettilo con sincerità. Mantieni il tono professionale ma accessibile."""
@app.route('/')
def index():
return render_template('index.html')
@app.route('/chat', methods=['POST'])
def chat():
data = request.json or {}
user_input = data.get("message", "")
if not user_input:
return Response(json.dumps({"error": "empty message"}), status=400)
# Costruisci il prompt con system message e chat template
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user_input}
]
# Applica il chat template se disponibile
if chat_template and hasattr(tokenizer, 'apply_chat_template'):
try:
prompt_text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
except Exception as e:
logging.warning("Errore applicando chat_template: %s, fallback a prompt semplice", e)
prompt_text = f"System: {SYSTEM_PROMPT}\nUser: {user_input}\nAssistant:"
else:
# Fallback semplice
prompt_text = f"System: {SYSTEM_PROMPT}\nUser: {user_input}\nAssistant:"
logging.info("Prompt generato: %s", prompt_text[:200])
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
inputs = tokenizer(prompt_text, return_tensors="pt")
# sposta gli input sulla GPU se disponibile
if device == "cuda":
inputs = {k: v.to("cuda") for k, v in inputs.items()}
# Parametri migliorati per stabilità e qualità
generation_kwargs = dict(
input_ids=inputs.get("input_ids"),
attention_mask=inputs.get("attention_mask"),
streamer=streamer,
max_new_tokens=2048,
temperature=0.5, # ridotto per più stabilità
do_sample=True,
top_p=0.80, # nucleus sampling per evitare token improbabili
top_k=40, # limita i candidati ai top-k
repetition_penalty=1.2, # penalizza ripetizioni
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
no_repeat_ngram_size=4, # evita n-grammi ripetuti
early_stopping=False,
)
def run_generate():
try:
with torch.no_grad():
model.generate(**generation_kwargs)
except Exception as e:
logging.exception("Errore durante la generazione:")
thread = Thread(target=run_generate)
thread.daemon = True
thread.start()
def generate():
try:
# yield streaming tokens in formato SSE
for new_text in streamer:
yield f"data: {json.dumps({'token': new_text})}\n\n"
except GeneratorExit:
logging.info("Client disconnected dalla stream")
except Exception:
logging.exception("Errore nello stream")
headers = {
'Cache-Control': 'no-cache',
'X-Accel-Buffering': 'no'
}
return Response(stream_with_context(generate()), mimetype='text/event-stream', headers=headers)
if __name__ == "__main__":
# HF Spaces richiede tassativamente la porta 7860
logging.info("Avvio app su 0.0.0.0:7860")
app.run(host='0.0.0.0', port=7860, threaded=True)