Mistral-7B-API / app.py
smartdigitalsolutions's picture
Update app.py
4eec20e verified
raw
history blame
8.12 kB
import os
import time
import json
import gradio as gr
from threading import Lock
from ctransformers import AutoModelForCausalLM
import fastapi
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
# Variabili globali
model = None
status_message = "Modello non ancora caricato"
MODEL_PATH = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
MODEL_FILE = "mistral-7b-instruct-v0.2.Q4_K_M.gguf" # Versione quantizzata per risparmiare memoria
MODEL_TYPE = "mistral"
MAX_NEW_TOKENS = 2048
MODEL_LOCK = Lock() # Per evitare richieste contemporanee che potrebbero causare OOM
# Definizioni dei modelli di dati (Pydantic)
class Message(BaseModel):
role: str
content: str
class CompletionRequest(BaseModel):
model: str
messages: List[Message]
temperature: Optional[float] = 0.7
top_p: Optional[float] = 0.95
max_tokens: Optional[int] = 2048
stream: Optional[bool] = False
stop: Optional[List[str]] = None
class CompletionResponse(BaseModel):
id: str
object: str = "chat.completion"
created: int
model: str
choices: List[Dict[str, Any]]
usage: Dict[str, int]
# Funzioni di utilità
def format_chat_prompt(messages: List[Message]) -> str:
"""Formatta i messaggi nel formato atteso da Mistral Instruct."""
conversation = []
for message in messages:
if message.role == "system":
# Inserisce il messaggio di sistema come istruzione iniziale
conversation.append(f"<s>[INST] {message.content} [/INST]</s>")
elif message.role == "user":
conversation.append(f"<s>[INST] {message.content} [/INST]</s>")
elif message.role == "assistant":
conversation.append(f"<s>{message.content}</s>")
return "".join(conversation)
def load_model():
"""Carica il modello Mistral quantizzato."""
global model, status_message
try:
status_message = "Caricamento modello in corso..."
model = AutoModelForCausalLM.from_pretrained(
MODEL_PATH,
model_file=MODEL_FILE,
model_type=MODEL_TYPE,
context_length=4096,
threads=4 # Usa 4 thread per lasciare risorse al sistema
)
status_message = "Modello caricato con successo"
return True
except Exception as e:
status_message = f"Errore nel caricamento del modello: {str(e)}"
return False
def generate_response(prompt, temperature=0.7, top_p=0.95, max_tokens=MAX_NEW_TOKENS):
"""Genera una risposta dal modello."""
global model, status_message
if model is None:
if not load_model():
return status_message
with MODEL_LOCK: # Previene richieste parallele che potrebbero causare OOM
try:
result = model(
prompt,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
repetition_penalty=1.1
)
return result
except Exception as e:
return f"Errore nella generazione: {str(e)}"
def generate_with_timing(text, temp, max_tok):
start_time = time.time()
prompt = f"<s>[INST] {text} [/INST]</s>"
result = generate_response(prompt, temperature=temp, max_tokens=max_tok)
end_time = time.time()
return result, f"{end_time - start_time:.2f} secondi"
# Creazione dell'interfaccia Gradio
def create_gradio_interface():
with gr.Blocks(title="Mistral API") as interface:
gr.Markdown("# Mistral-7B API Server")
with gr.Row():
with gr.Column():
status = gr.Textbox(value=lambda: status_message, label="Stato del modello", interactive=False)
load_button = gr.Button("Carica Modello")
load_button.click(load_model, inputs=[], outputs=[])
with gr.Row():
with gr.Column():
input_text = gr.Textbox(
lines=5,
label="Input",
placeholder="Inserisci il tuo messaggio qui..."
)
with gr.Row():
temp_slider = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.7,
step=0.1,
label="Temperatura"
)
max_token_slider = gr.Slider(
minimum=100,
maximum=MAX_NEW_TOKENS,
value=1024,
step=100,
label="Max Token"
)
submit_button = gr.Button("Genera")
with gr.Column():
output_text = gr.Textbox(lines=12, label="Risposta del modello")
gen_time = gr.Textbox(label="Tempo di generazione", interactive=False)
submit_button.click(
generate_with_timing,
inputs=[input_text, temp_slider, max_token_slider],
outputs=[output_text, gen_time]
)
gr.Markdown("""
## API Endpoint
Questa applicazione espone un endpoint API compatibile con OpenAI:
- `/v1/chat/completions` - Per richieste di completamento chat
- `/status` - Per verificare lo stato del modello
L'endpoint è accessibile dall'URL di questo Hugging Face Space.
""")
return interface
# Crea l'applicazione FastAPI
app = FastAPI()
# Configura CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# API endpoint compatibile con OpenAI
@app.post("/v1/chat/completions", response_model=CompletionResponse)
async def create_completion(request: CompletionRequest):
try:
prompt = format_chat_prompt(request.messages)
max_tokens = min(request.max_tokens, MAX_NEW_TOKENS) # Limita i token per evitare OOM
start_time = time.time()
completion_text = generate_response(
prompt,
temperature=request.temperature,
top_p=request.top_p,
max_tokens=max_tokens
)
end_time = time.time()
# Calcola il numero di token (approssimativo)
input_tokens = len(prompt.split())
output_tokens = len(completion_text.split())
response = {
"id": f"chatcmpl-{os.urandom(4).hex()}",
"object": "chat.completion",
"created": int(time.time()),
"model": request.model,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": completion_text,
},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": input_tokens,
"completion_tokens": output_tokens,
"total_tokens": input_tokens + output_tokens,
}
}
return response
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# API endpoint per verificare lo stato del modello
@app.get("/status")
async def get_status():
return {"status": status_message, "model": MODEL_PATH}
# Crea l'interfaccia Gradio
demo = create_gradio_interface()
# Monta Gradio su FastAPI usando il metodo corretto per Gradio 4
app = gr.mount_gradio_app(app, demo, path="/")
# Precarica il modello all'avvio (usando il nuovo metodo lifespan invece di on_event)
@app.on_event("startup")
async def startup_load_model():
load_model()
# Per Hugging Face Spaces, assicurati che l'app sia esportata correttamente
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)