import os
import time
from threading import Lock
from typing import Any, Dict, List, Optional

import gradio as gr
from ctransformers import AutoModelForCausalLM
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel


# Global model state and configuration
model = None
status_message = "Model not loaded yet"
MODEL_PATH = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
MODEL_FILE = "mistral-7b-instruct-v0.2.Q4_K_M.gguf"
MODEL_TYPE = "mistral"
MAX_NEW_TOKENS = 2048
MODEL_LOCK = Lock()


class Message(BaseModel):
    role: str
    content: str


class CompletionRequest(BaseModel):
    model: str
    messages: List[Message]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    max_tokens: Optional[int] = 2048
    stream: Optional[bool] = False
    stop: Optional[List[str]] = None


class CompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Dict[str, Any]]
    usage: Dict[str, int]


def format_chat_prompt(messages: List[Message]) -> str:
    """Format messages into the prompt layout expected by Mistral Instruct.

    Any system prompt is folded into the first user turn; each assistant
    reply is terminated with </s>, while [INST] blocks are left open for
    the model to continue.
    """
    conversation = ["<s>"]
    system_prompt = ""

    for message in messages:
        if message.role == "system":
            # Mistral Instruct has no native system role.
            system_prompt = message.content
        elif message.role == "user":
            content = f"{system_prompt}\n\n{message.content}" if system_prompt else message.content
            system_prompt = ""
            conversation.append(f"[INST] {content} [/INST]")
        elif message.role == "assistant":
            conversation.append(f" {message.content}</s>")

    return "".join(conversation)


def load_model():
    """Load the quantized Mistral model."""
    global model, status_message
    try:
        status_message = "Loading model..."
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            model_file=MODEL_FILE,
            model_type=MODEL_TYPE,
            context_length=4096,
            threads=4
        )
        status_message = "Model loaded successfully"
        return True
    except Exception as e:
        status_message = f"Error while loading the model: {str(e)}"
        return False


def generate_response(prompt, temperature=0.7, top_p=0.95, max_tokens=MAX_NEW_TOKENS):
    """Generate a response from the model, loading it on first use."""
    global model, status_message

    if model is None:
        if not load_model():
            return status_message

    # The ctransformers model is not safe for concurrent calls, so serialize access.
    with MODEL_LOCK:
        try:
            result = model(
                prompt,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=1.1
            )
            return result
        except Exception as e:
            return f"Generation error: {str(e)}"


def generate_with_timing(text, temp, max_tok):
    """Generate a reply for the Gradio UI and report the elapsed time."""
    start_time = time.time()
    prompt = f"<s>[INST] {text} [/INST]"
    result = generate_response(prompt, temperature=temp, max_tokens=max_tok)
    end_time = time.time()
    return result, f"{end_time - start_time:.2f} seconds"


def create_gradio_interface():
    with gr.Blocks(title="Mistral API") as interface:
        gr.Markdown("# Mistral-7B API Server")

        with gr.Row():
            with gr.Column():
                status = gr.Textbox(value=lambda: status_message, label="Model status", interactive=False)
                load_button = gr.Button("Load Model")

                def load_and_refresh_status():
                    load_model()
                    return status_message

                load_button.click(load_and_refresh_status, inputs=[], outputs=[status])

        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    lines=5,
                    label="Input",
                    placeholder="Type your message here..."
                )

                with gr.Row():
                    temp_slider = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature"
                    )

                    max_token_slider = gr.Slider(
                        minimum=100,
                        maximum=MAX_NEW_TOKENS,
                        value=1024,
                        step=100,
                        label="Max Tokens"
                    )

                submit_button = gr.Button("Generate")

            with gr.Column():
                output_text = gr.Textbox(lines=12, label="Model response")

                gen_time = gr.Textbox(label="Generation time", interactive=False)

        submit_button.click(
            generate_with_timing,
            inputs=[input_text, temp_slider, max_token_slider],
            outputs=[output_text, gen_time]
        )

        gr.Markdown("""
        ## API Endpoint
        This application exposes an OpenAI-compatible API:
        - `/v1/chat/completions` - chat completion requests
        - `/status` - check the model status

        The endpoint is reachable at this Hugging Face Space's URL.
        """)

    return interface


app = FastAPI()

# Allow any origin so browser-based clients can reach the API; tighten in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/v1/chat/completions", response_model=CompletionResponse)
async def create_completion(request: CompletionRequest):
    try:
        prompt = format_chat_prompt(request.messages)

        # Clamp the requested budget; `stream` and `stop` are accepted for
        # compatibility but not honoured by this server.
        max_tokens = min(request.max_tokens or MAX_NEW_TOKENS, MAX_NEW_TOKENS)

        completion_text = generate_response(
            prompt,
            temperature=request.temperature,
            top_p=request.top_p,
            max_tokens=max_tokens
        )

        # Rough token counts from whitespace splitting, not the real tokenizer.
        input_tokens = len(prompt.split())
        output_tokens = len(completion_text.split())

        response = {
            "id": f"chatcmpl-{os.urandom(4).hex()}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": request.model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": completion_text,
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": input_tokens,
                "completion_tokens": output_tokens,
                "total_tokens": input_tokens + output_tokens,
            }
        }

        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/status")
async def get_status():
    return {"status": status_message, "model": MODEL_PATH}


demo = create_gradio_interface()

app = gr.mount_gradio_app(app, demo, path="/")


@app.on_event("startup")
async def startup_load_model():
    load_model()


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
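

# Client-side sketch (not executed here, compatibility untested): because the
# endpoint mimics the OpenAI Chat Completions schema, the OpenAI Python SDK
# can be pointed at this server via base_url, e.g.:
#
#     from openai import OpenAI
#     client = OpenAI(base_url="http://localhost:7860/v1", api_key="not-needed")
#     reply = client.chat.completions.create(
#         model="mistral-7b-instruct",
#         messages=[{"role": "user", "content": "Hello!"}],
#     )
#     print(reply.choices[0].message.content)
#
# The model name is only echoed back, and api_key can be any non-empty string
# since this server performs no authentication.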