smartdigitalsolutions commited on
Commit
4eec20e
·
verified ·
1 Parent(s): c8587d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +244 -237
app.py CHANGED
@@ -1,237 +1,244 @@
1
- import os
2
- import time
3
- import json
4
- import gradio as gr
5
- from threading import Lock
6
- from ctransformers import AutoModelForCausalLM
7
- import fastapi
8
- from fastapi import FastAPI, HTTPException, Request
9
- from fastapi.middleware.cors import CORSMiddleware
10
- from pydantic import BaseModel
11
- from typing import List, Optional, Dict, Any
12
-
13
- # Configurazione FastAPI come backend di Gradio
14
- app = fastapi.FastAPI()
15
- app.add_middleware(
16
- CORSMiddleware,
17
- allow_origins=["*"],
18
- allow_credentials=True,
19
- allow_methods=["*"],
20
- allow_headers=["*"],
21
- )
22
-
23
- # Configurazione modello
24
- MODEL_PATH = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
25
- MODEL_FILE = "mistral-7b-instruct-v0.2.Q4_K_M.gguf" # Versione quantizzata per risparmiare memoria
26
- MODEL_TYPE = "mistral"
27
- MAX_NEW_TOKENS = 2048
28
- MODEL_LOCK = Lock() # Per evitare richieste contemporanee che potrebbero causare OOM
29
-
30
- # Variabili globali
31
- model = None
32
- status_message = "Modello non ancora caricato"
33
-
34
- # Definizioni dei modelli di dati (Pydantic)
35
- class Message(BaseModel):
36
- role: str
37
- content: str
38
-
39
- class CompletionRequest(BaseModel):
40
- model: str
41
- messages: List[Message]
42
- temperature: Optional[float] = 0.7
43
- top_p: Optional[float] = 0.95
44
- max_tokens: Optional[int] = 2048
45
- stream: Optional[bool] = False
46
- stop: Optional[List[str]] = None
47
-
48
- class CompletionResponse(BaseModel):
49
- id: str
50
- object: str = "chat.completion"
51
- created: int
52
- model: str
53
- choices: List[Dict[str, Any]]
54
- usage: Dict[str, int]
55
-
56
- # Funzioni di utilità
57
- def format_chat_prompt(messages: List[Message]) -> str:
58
- """Formatta i messaggi nel formato atteso da Mistral Instruct."""
59
- conversation = []
60
-
61
- for message in messages:
62
- if message.role == "system":
63
- # Inserisce il messaggio di sistema come istruzione iniziale
64
- conversation.append(f"<s>[INST] {message.content} [/INST]</s>")
65
- elif message.role == "user":
66
- conversation.append(f"<s>[INST] {message.content} [/INST]</s>")
67
- elif message.role == "assistant":
68
- conversation.append(f"<s>{message.content}</s>")
69
-
70
- return "".join(conversation)
71
-
72
- def load_model():
73
- """Carica il modello Mistral quantizzato."""
74
- global model, status_message
75
- try:
76
- status_message = "Caricamento modello in corso..."
77
- model = AutoModelForCausalLM.from_pretrained(
78
- MODEL_PATH,
79
- model_file=MODEL_FILE,
80
- model_type=MODEL_TYPE,
81
- context_length=4096,
82
- threads=4 # Usa 4 thread per lasciare risorse al sistema
83
- )
84
- status_message = "Modello caricato con successo"
85
- return True
86
- except Exception as e:
87
- status_message = f"Errore nel caricamento del modello: {str(e)}"
88
- return False
89
-
90
- def generate_response(prompt, temperature=0.7, top_p=0.95, max_tokens=MAX_NEW_TOKENS):
91
- """Genera una risposta dal modello."""
92
- global model, status_message
93
-
94
- if model is None:
95
- if not load_model():
96
- return status_message
97
-
98
- with MODEL_LOCK: # Previene richieste parallele che potrebbero causare OOM
99
- try:
100
- result = model(
101
- prompt,
102
- max_new_tokens=max_tokens,
103
- temperature=temperature,
104
- top_p=top_p,
105
- repetition_penalty=1.1
106
- )
107
- return result
108
- except Exception as e:
109
- return f"Errore nella generazione: {str(e)}"
110
-
111
- # API endpoint compatibile con OpenAI
112
- @app.post("/v1/chat/completions", response_model=CompletionResponse)
113
- async def create_completion(request: CompletionRequest):
114
- try:
115
- prompt = format_chat_prompt(request.messages)
116
-
117
- max_tokens = min(request.max_tokens, MAX_NEW_TOKENS) # Limita i token per evitare OOM
118
-
119
- start_time = time.time()
120
- completion_text = generate_response(
121
- prompt,
122
- temperature=request.temperature,
123
- top_p=request.top_p,
124
- max_tokens=max_tokens
125
- )
126
- end_time = time.time()
127
-
128
- # Calcola il numero di token (approssimativo)
129
- input_tokens = len(prompt.split())
130
- output_tokens = len(completion_text.split())
131
-
132
- response = {
133
- "id": f"chatcmpl-{os.urandom(4).hex()}",
134
- "object": "chat.completion",
135
- "created": int(time.time()),
136
- "model": request.model,
137
- "choices": [
138
- {
139
- "index": 0,
140
- "message": {
141
- "role": "assistant",
142
- "content": completion_text,
143
- },
144
- "finish_reason": "stop",
145
- }
146
- ],
147
- "usage": {
148
- "prompt_tokens": input_tokens,
149
- "completion_tokens": output_tokens,
150
- "total_tokens": input_tokens + output_tokens,
151
- }
152
- }
153
-
154
- return response
155
- except Exception as e:
156
- raise HTTPException(status_code=500, detail=str(e))
157
-
158
- # API endpoint per verificare lo stato del modello
159
- @app.get("/status")
160
- async def get_status():
161
- return {"status": status_message, "model": MODEL_PATH}
162
-
163
- # Interfaccia Gradio per testing manuale
164
- def create_gradio_interface():
165
- with gr.Blocks(title="Mistral API") as interface:
166
- gr.Markdown("# Mistral-7B API Server")
167
-
168
- with gr.Row():
169
- with gr.Column():
170
- status = gr.Textbox(value=lambda: status_message, label="Stato del modello", interactive=False)
171
- load_button = gr.Button("Carica Modello")
172
- load_button.click(load_model, inputs=[], outputs=[])
173
-
174
- with gr.Row():
175
- with gr.Column():
176
- input_text = gr.Textbox(
177
- lines=5,
178
- label="Input",
179
- placeholder="Inserisci il tuo messaggio qui..."
180
- )
181
-
182
- with gr.Row():
183
- temp_slider = gr.Slider(
184
- minimum=0.1,
185
- maximum=1.0,
186
- value=0.7,
187
- step=0.1,
188
- label="Temperatura"
189
- )
190
-
191
- max_token_slider = gr.Slider(
192
- minimum=100,
193
- maximum=MAX_NEW_TOKENS,
194
- value=1024,
195
- step=100,
196
- label="Max Token"
197
- )
198
-
199
- submit_button = gr.Button("Genera")
200
-
201
- with gr.Column():
202
- output_text = gr.Textbox(lines=12, label="Risposta del modello")
203
-
204
- gen_time = gr.Textbox(label="Tempo di generazione", interactive=False)
205
-
206
- def generate_with_timing(text, temp, max_tok):
207
- start_time = time.time()
208
- prompt = f"<s>[INST] {text} [/INST]</s>"
209
- result = generate_response(prompt, temperature=temp, max_tokens=max_tok)
210
- end_time = time.time()
211
- return result, f"{end_time - start_time:.2f} secondi"
212
-
213
- submit_button.click(
214
- generate_with_timing,
215
- inputs=[input_text, temp_slider, max_token_slider],
216
- outputs=[output_text, gen_time]
217
- )
218
-
219
- gr.Markdown("""
220
- ## API Endpoint
221
- Questa applicazione espone un endpoint API compatibile con OpenAI:
222
- - `/v1/chat/completions` - Per richieste di completamento chat
223
- - `/status` - Per verificare lo stato del modello
224
-
225
- L'endpoint è accessibile dall'URL di questo Hugging Face Space.
226
- """)
227
-
228
- return interface
229
-
230
- # Inizializza e avvia l'app Gradio
231
- demo = create_gradio_interface()
232
- app = gr.mount_gradio_app(app, demo, path="/")
233
-
234
- # Precarica il modello al primo avvio
235
- @app.on_event("startup")
236
- async def startup_load_model():
237
- load_model()
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import json
4
+ import gradio as gr
5
+ from threading import Lock
6
+ from ctransformers import AutoModelForCausalLM
7
+ import fastapi
8
+ from fastapi import FastAPI, HTTPException, Request
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from pydantic import BaseModel
11
+ from typing import List, Optional, Dict, Any
12
+
13
+ # Variabili globali
14
+ model = None
15
+ status_message = "Modello non ancora caricato"
16
+ MODEL_PATH = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
17
+ MODEL_FILE = "mistral-7b-instruct-v0.2.Q4_K_M.gguf" # Versione quantizzata per risparmiare memoria
18
+ MODEL_TYPE = "mistral"
19
+ MAX_NEW_TOKENS = 2048
20
+ MODEL_LOCK = Lock() # Per evitare richieste contemporanee che potrebbero causare OOM
21
+
22
+ # Definizioni dei modelli di dati (Pydantic)
23
+ class Message(BaseModel):
24
+ role: str
25
+ content: str
26
+
27
+ class CompletionRequest(BaseModel):
28
+ model: str
29
+ messages: List[Message]
30
+ temperature: Optional[float] = 0.7
31
+ top_p: Optional[float] = 0.95
32
+ max_tokens: Optional[int] = 2048
33
+ stream: Optional[bool] = False
34
+ stop: Optional[List[str]] = None
35
+
36
+ class CompletionResponse(BaseModel):
37
+ id: str
38
+ object: str = "chat.completion"
39
+ created: int
40
+ model: str
41
+ choices: List[Dict[str, Any]]
42
+ usage: Dict[str, int]
43
+
44
+ # Funzioni di utilità
45
+ def format_chat_prompt(messages: List[Message]) -> str:
46
+ """Formatta i messaggi nel formato atteso da Mistral Instruct."""
47
+ conversation = []
48
+
49
+ for message in messages:
50
+ if message.role == "system":
51
+ # Inserisce il messaggio di sistema come istruzione iniziale
52
+ conversation.append(f"<s>[INST] {message.content} [/INST]</s>")
53
+ elif message.role == "user":
54
+ conversation.append(f"<s>[INST] {message.content} [/INST]</s>")
55
+ elif message.role == "assistant":
56
+ conversation.append(f"<s>{message.content}</s>")
57
+
58
+ return "".join(conversation)
59
+
60
+ def load_model():
61
+ """Carica il modello Mistral quantizzato."""
62
+ global model, status_message
63
+ try:
64
+ status_message = "Caricamento modello in corso..."
65
+ model = AutoModelForCausalLM.from_pretrained(
66
+ MODEL_PATH,
67
+ model_file=MODEL_FILE,
68
+ model_type=MODEL_TYPE,
69
+ context_length=4096,
70
+ threads=4 # Usa 4 thread per lasciare risorse al sistema
71
+ )
72
+ status_message = "Modello caricato con successo"
73
+ return True
74
+ except Exception as e:
75
+ status_message = f"Errore nel caricamento del modello: {str(e)}"
76
+ return False
77
+
78
+ def generate_response(prompt, temperature=0.7, top_p=0.95, max_tokens=MAX_NEW_TOKENS):
79
+ """Genera una risposta dal modello."""
80
+ global model, status_message
81
+
82
+ if model is None:
83
+ if not load_model():
84
+ return status_message
85
+
86
+ with MODEL_LOCK: # Previene richieste parallele che potrebbero causare OOM
87
+ try:
88
+ result = model(
89
+ prompt,
90
+ max_new_tokens=max_tokens,
91
+ temperature=temperature,
92
+ top_p=top_p,
93
+ repetition_penalty=1.1
94
+ )
95
+ return result
96
+ except Exception as e:
97
+ return f"Errore nella generazione: {str(e)}"
98
+
99
+ def generate_with_timing(text, temp, max_tok):
100
+ start_time = time.time()
101
+ prompt = f"<s>[INST] {text} [/INST]</s>"
102
+ result = generate_response(prompt, temperature=temp, max_tokens=max_tok)
103
+ end_time = time.time()
104
+ return result, f"{end_time - start_time:.2f} secondi"
105
+
106
+ # Creazione dell'interfaccia Gradio
107
+ def create_gradio_interface():
108
+ with gr.Blocks(title="Mistral API") as interface:
109
+ gr.Markdown("# Mistral-7B API Server")
110
+
111
+ with gr.Row():
112
+ with gr.Column():
113
+ status = gr.Textbox(value=lambda: status_message, label="Stato del modello", interactive=False)
114
+ load_button = gr.Button("Carica Modello")
115
+ load_button.click(load_model, inputs=[], outputs=[])
116
+
117
+ with gr.Row():
118
+ with gr.Column():
119
+ input_text = gr.Textbox(
120
+ lines=5,
121
+ label="Input",
122
+ placeholder="Inserisci il tuo messaggio qui..."
123
+ )
124
+
125
+ with gr.Row():
126
+ temp_slider = gr.Slider(
127
+ minimum=0.1,
128
+ maximum=1.0,
129
+ value=0.7,
130
+ step=0.1,
131
+ label="Temperatura"
132
+ )
133
+
134
+ max_token_slider = gr.Slider(
135
+ minimum=100,
136
+ maximum=MAX_NEW_TOKENS,
137
+ value=1024,
138
+ step=100,
139
+ label="Max Token"
140
+ )
141
+
142
+ submit_button = gr.Button("Genera")
143
+
144
+ with gr.Column():
145
+ output_text = gr.Textbox(lines=12, label="Risposta del modello")
146
+
147
+ gen_time = gr.Textbox(label="Tempo di generazione", interactive=False)
148
+
149
+ submit_button.click(
150
+ generate_with_timing,
151
+ inputs=[input_text, temp_slider, max_token_slider],
152
+ outputs=[output_text, gen_time]
153
+ )
154
+
155
+ gr.Markdown("""
156
+ ## API Endpoint
157
+ Questa applicazione espone un endpoint API compatibile con OpenAI:
158
+ - `/v1/chat/completions` - Per richieste di completamento chat
159
+ - `/status` - Per verificare lo stato del modello
160
+
161
+ L'endpoint è accessibile dall'URL di questo Hugging Face Space.
162
+ """)
163
+
164
+ return interface
165
+
166
+ # Crea l'applicazione FastAPI
167
+ app = FastAPI()
168
+
169
+ # Configura CORS
170
+ app.add_middleware(
171
+ CORSMiddleware,
172
+ allow_origins=["*"],
173
+ allow_credentials=True,
174
+ allow_methods=["*"],
175
+ allow_headers=["*"],
176
+ )
177
+
178
+ # API endpoint compatibile con OpenAI
179
+ @app.post("/v1/chat/completions", response_model=CompletionResponse)
180
+ async def create_completion(request: CompletionRequest):
181
+ try:
182
+ prompt = format_chat_prompt(request.messages)
183
+
184
+ max_tokens = min(request.max_tokens, MAX_NEW_TOKENS) # Limita i token per evitare OOM
185
+
186
+ start_time = time.time()
187
+ completion_text = generate_response(
188
+ prompt,
189
+ temperature=request.temperature,
190
+ top_p=request.top_p,
191
+ max_tokens=max_tokens
192
+ )
193
+ end_time = time.time()
194
+
195
+ # Calcola il numero di token (approssimativo)
196
+ input_tokens = len(prompt.split())
197
+ output_tokens = len(completion_text.split())
198
+
199
+ response = {
200
+ "id": f"chatcmpl-{os.urandom(4).hex()}",
201
+ "object": "chat.completion",
202
+ "created": int(time.time()),
203
+ "model": request.model,
204
+ "choices": [
205
+ {
206
+ "index": 0,
207
+ "message": {
208
+ "role": "assistant",
209
+ "content": completion_text,
210
+ },
211
+ "finish_reason": "stop",
212
+ }
213
+ ],
214
+ "usage": {
215
+ "prompt_tokens": input_tokens,
216
+ "completion_tokens": output_tokens,
217
+ "total_tokens": input_tokens + output_tokens,
218
+ }
219
+ }
220
+
221
+ return response
222
+ except Exception as e:
223
+ raise HTTPException(status_code=500, detail=str(e))
224
+
225
+ # API endpoint per verificare lo stato del modello
226
+ @app.get("/status")
227
+ async def get_status():
228
+ return {"status": status_message, "model": MODEL_PATH}
229
+
230
+ # Crea l'interfaccia Gradio
231
+ demo = create_gradio_interface()
232
+
233
+ # Monta Gradio su FastAPI usando il metodo corretto per Gradio 4
234
+ app = gr.mount_gradio_app(app, demo, path="/")
235
+
236
+ # Precarica il modello all'avvio (usando il nuovo metodo lifespan invece di on_event)
237
+ @app.on_event("startup")
238
+ async def startup_load_model():
239
+ load_model()
240
+
241
+ # Per Hugging Face Spaces, assicurati che l'app sia esportata correttamente
242
+ if __name__ == "__main__":
243
+ import uvicorn
244
+ uvicorn.run(app, host="0.0.0.0", port=7860)