smartdigitalsolutions commited on
Commit
502a685
·
verified ·
1 Parent(s): de2a400

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +47 -12
  2. app.py +237 -0
  3. requirements.txt +6 -0
README.md CHANGED
@@ -1,12 +1,47 @@
1
- ---
2
- title: Mistral 7B API
3
- emoji: 🔥
4
- colorFrom: yellow
5
- colorTo: pink
6
- sdk: gradio
7
- sdk_version: 5.30.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Mistral-7B-API
3
+ emoji: 🤖
4
+ colorFrom: indigo
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 4.13.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ # Mistral-7B API Server
14
+
15
+ Questo Space fornisce un'API compatibile con OpenAI per il modello Mistral-7B-Instruct-v0.2. L'API è accessibile tramite l'endpoint `/v1/chat/completions`.
16
+
17
+ ## Caratteristiche
18
+ - Versione quantizzata GGUF del modello Mistral-7B-Instruct-v0.2
19
+ - API compatibile con OpenAI
20
+ - Interfaccia di test Gradio per verificare il funzionamento del modello
21
+
22
+ ## API Usage
23
+ ```python
24
+ import requests
25
+ import json
26
+
27
+ headers = {
28
+ "Content-Type": "application/json"
29
+ }
30
+
31
+ data = {
32
+ "model": "mistral-7b-instruct",
33
+ "messages": [
34
+ {"role": "user", "content": "Quali sono le principali città italiane?"}
35
+ ],
36
+ "temperature": 0.7,
37
+ "max_tokens": 1024
38
+ }
39
+
40
+ response = requests.post(
41
+ "https://huggingface.co/spaces/[username]/Mistral-7B-API/v1/chat/completions",
42
+ headers=headers,
43
+ json=data
44
+ )
45
+
46
+ print(json.dumps(response.json(), indent=2))
47
+ ```
app.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import json
4
+ import gradio as gr
5
+ from threading import Lock
6
+ from ctransformers import AutoModelForCausalLM
7
+ import fastapi
8
+ from fastapi import FastAPI, HTTPException, Request
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from pydantic import BaseModel
11
+ from typing import List, Optional, Dict, Any
12
+
13
+ # Configurazione FastAPI come backend di Gradio
14
+ app = fastapi.FastAPI()
15
+ app.add_middleware(
16
+ CORSMiddleware,
17
+ allow_origins=["*"],
18
+ allow_credentials=True,
19
+ allow_methods=["*"],
20
+ allow_headers=["*"],
21
+ )
22
+
23
+ # Configurazione modello
24
+ MODEL_PATH = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
25
+ MODEL_FILE = "mistral-7b-instruct-v0.2.Q4_K_M.gguf" # Versione quantizzata per risparmiare memoria
26
+ MODEL_TYPE = "mistral"
27
+ MAX_NEW_TOKENS = 2048
28
+ MODEL_LOCK = Lock() # Per evitare richieste contemporanee che potrebbero causare OOM
29
+
30
+ # Variabili globali
31
+ model = None
32
+ status_message = "Modello non ancora caricato"
33
+
34
+ # Definizioni dei modelli di dati (Pydantic)
35
+ class Message(BaseModel):
36
+ role: str
37
+ content: str
38
+
39
+ class CompletionRequest(BaseModel):
40
+ model: str
41
+ messages: List[Message]
42
+ temperature: Optional[float] = 0.7
43
+ top_p: Optional[float] = 0.95
44
+ max_tokens: Optional[int] = 2048
45
+ stream: Optional[bool] = False
46
+ stop: Optional[List[str]] = None
47
+
48
+ class CompletionResponse(BaseModel):
49
+ id: str
50
+ object: str = "chat.completion"
51
+ created: int
52
+ model: str
53
+ choices: List[Dict[str, Any]]
54
+ usage: Dict[str, int]
55
+
56
+ # Funzioni di utilità
57
+ def format_chat_prompt(messages: List[Message]) -> str:
58
+ """Formatta i messaggi nel formato atteso da Mistral Instruct."""
59
+ conversation = []
60
+
61
+ for message in messages:
62
+ if message.role == "system":
63
+ # Inserisce il messaggio di sistema come istruzione iniziale
64
+ conversation.append(f"<s>[INST] {message.content} [/INST]</s>")
65
+ elif message.role == "user":
66
+ conversation.append(f"<s>[INST] {message.content} [/INST]</s>")
67
+ elif message.role == "assistant":
68
+ conversation.append(f"<s>{message.content}</s>")
69
+
70
+ return "".join(conversation)
71
+
72
+ def load_model():
73
+ """Carica il modello Mistral quantizzato."""
74
+ global model, status_message
75
+ try:
76
+ status_message = "Caricamento modello in corso..."
77
+ model = AutoModelForCausalLM.from_pretrained(
78
+ MODEL_PATH,
79
+ model_file=MODEL_FILE,
80
+ model_type=MODEL_TYPE,
81
+ context_length=4096,
82
+ threads=4 # Usa 4 thread per lasciare risorse al sistema
83
+ )
84
+ status_message = "Modello caricato con successo"
85
+ return True
86
+ except Exception as e:
87
+ status_message = f"Errore nel caricamento del modello: {str(e)}"
88
+ return False
89
+
90
+ def generate_response(prompt, temperature=0.7, top_p=0.95, max_tokens=MAX_NEW_TOKENS):
91
+ """Genera una risposta dal modello."""
92
+ global model, status_message
93
+
94
+ if model is None:
95
+ if not load_model():
96
+ return status_message
97
+
98
+ with MODEL_LOCK: # Previene richieste parallele che potrebbero causare OOM
99
+ try:
100
+ result = model(
101
+ prompt,
102
+ max_new_tokens=max_tokens,
103
+ temperature=temperature,
104
+ top_p=top_p,
105
+ repetition_penalty=1.1
106
+ )
107
+ return result
108
+ except Exception as e:
109
+ return f"Errore nella generazione: {str(e)}"
110
+
111
+ # API endpoint compatibile con OpenAI
112
+ @app.post("/v1/chat/completions", response_model=CompletionResponse)
113
+ async def create_completion(request: CompletionRequest):
114
+ try:
115
+ prompt = format_chat_prompt(request.messages)
116
+
117
+ max_tokens = min(request.max_tokens, MAX_NEW_TOKENS) # Limita i token per evitare OOM
118
+
119
+ start_time = time.time()
120
+ completion_text = generate_response(
121
+ prompt,
122
+ temperature=request.temperature,
123
+ top_p=request.top_p,
124
+ max_tokens=max_tokens
125
+ )
126
+ end_time = time.time()
127
+
128
+ # Calcola il numero di token (approssimativo)
129
+ input_tokens = len(prompt.split())
130
+ output_tokens = len(completion_text.split())
131
+
132
+ response = {
133
+ "id": f"chatcmpl-{os.urandom(4).hex()}",
134
+ "object": "chat.completion",
135
+ "created": int(time.time()),
136
+ "model": request.model,
137
+ "choices": [
138
+ {
139
+ "index": 0,
140
+ "message": {
141
+ "role": "assistant",
142
+ "content": completion_text,
143
+ },
144
+ "finish_reason": "stop",
145
+ }
146
+ ],
147
+ "usage": {
148
+ "prompt_tokens": input_tokens,
149
+ "completion_tokens": output_tokens,
150
+ "total_tokens": input_tokens + output_tokens,
151
+ }
152
+ }
153
+
154
+ return response
155
+ except Exception as e:
156
+ raise HTTPException(status_code=500, detail=str(e))
157
+
158
+ # API endpoint per verificare lo stato del modello
159
+ @app.get("/status")
160
+ async def get_status():
161
+ return {"status": status_message, "model": MODEL_PATH}
162
+
163
+ # Interfaccia Gradio per testing manuale
164
+ def create_gradio_interface():
165
+ with gr.Blocks(title="Mistral API") as interface:
166
+ gr.Markdown("# Mistral-7B API Server")
167
+
168
+ with gr.Row():
169
+ with gr.Column():
170
+ status = gr.Textbox(value=lambda: status_message, label="Stato del modello", interactive=False)
171
+ load_button = gr.Button("Carica Modello")
172
+ load_button.click(load_model, inputs=[], outputs=[])
173
+
174
+ with gr.Row():
175
+ with gr.Column():
176
+ input_text = gr.Textbox(
177
+ lines=5,
178
+ label="Input",
179
+ placeholder="Inserisci il tuo messaggio qui..."
180
+ )
181
+
182
+ with gr.Row():
183
+ temp_slider = gr.Slider(
184
+ minimum=0.1,
185
+ maximum=1.0,
186
+ value=0.7,
187
+ step=0.1,
188
+ label="Temperatura"
189
+ )
190
+
191
+ max_token_slider = gr.Slider(
192
+ minimum=100,
193
+ maximum=MAX_NEW_TOKENS,
194
+ value=1024,
195
+ step=100,
196
+ label="Max Token"
197
+ )
198
+
199
+ submit_button = gr.Button("Genera")
200
+
201
+ with gr.Column():
202
+ output_text = gr.Textbox(lines=12, label="Risposta del modello")
203
+
204
+ gen_time = gr.Textbox(label="Tempo di generazione", interactive=False)
205
+
206
+ def generate_with_timing(text, temp, max_tok):
207
+ start_time = time.time()
208
+ prompt = f"<s>[INST] {text} [/INST]</s>"
209
+ result = generate_response(prompt, temperature=temp, max_tokens=max_tok)
210
+ end_time = time.time()
211
+ return result, f"{end_time - start_time:.2f} secondi"
212
+
213
+ submit_button.click(
214
+ generate_with_timing,
215
+ inputs=[input_text, temp_slider, max_token_slider],
216
+ outputs=[output_text, gen_time]
217
+ )
218
+
219
+ gr.Markdown("""
220
+ ## API Endpoint
221
+ Questa applicazione espone un endpoint API compatibile con OpenAI:
222
+ - `/v1/chat/completions` - Per richieste di completamento chat
223
+ - `/status` - Per verificare lo stato del modello
224
+
225
+ L'endpoint è accessibile dall'URL di questo Hugging Face Space.
226
+ """)
227
+
228
+ return interface
229
+
230
+ # Inizializza e avvia l'app Gradio
231
+ demo = create_gradio_interface()
232
+ app = gr.mount_gradio_app(app, demo, path="/")
233
+
234
+ # Precarica il modello al primo avvio
235
+ @app.on_event("startup")
236
+ async def startup_load_model():
237
+ load_model()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ accelerate==0.25.0
2
+ fastapi==0.105.0
3
+ gradio==4.13.0
4
+ pydantic==2.5.0
5
+ ctransformers[cuda]==0.2.27
6
+ uvicorn==0.24.0