Files changed (1) hide show
  1. app.py +0 -237
app.py DELETED
@@ -1,237 +0,0 @@
1
- import os
2
- import time
3
- import json
4
- import gradio as gr
5
- from threading import Lock
6
- from ctransformers import AutoModelForCausalLM
7
- import fastapi
8
- from fastapi import FastAPI, HTTPException, Request
9
- from fastapi.middleware.cors import CORSMiddleware
10
- from pydantic import BaseModel
11
- from typing import List, Optional, Dict, Any
12
-
13
- # Configurazione FastAPI come backend di Gradio
14
- app = fastapi.FastAPI()
15
- app.add_middleware(
16
- CORSMiddleware,
17
- allow_origins=["*"],
18
- allow_credentials=True,
19
- allow_methods=["*"],
20
- allow_headers=["*"],
21
- )
22
-
23
- # Configurazione modello
24
- MODEL_PATH = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
25
- MODEL_FILE = "mistral-7b-instruct-v0.2.Q4_K_M.gguf" # Versione quantizzata per risparmiare memoria
26
- MODEL_TYPE = "mistral"
27
- MAX_NEW_TOKENS = 2048
28
- MODEL_LOCK = Lock() # Per evitare richieste contemporanee che potrebbero causare OOM
29
-
30
- # Variabili globali
31
- model = None
32
- status_message = "Modello non ancora caricato"
33
-
34
- # Definizioni dei modelli di dati (Pydantic)
35
- class Message(BaseModel):
36
- role: str
37
- content: str
38
-
39
- class CompletionRequest(BaseModel):
40
- model: str
41
- messages: List[Message]
42
- temperature: Optional[float] = 0.7
43
- top_p: Optional[float] = 0.95
44
- max_tokens: Optional[int] = 2048
45
- stream: Optional[bool] = False
46
- stop: Optional[List[str]] = None
47
-
48
- class CompletionResponse(BaseModel):
49
- id: str
50
- object: str = "chat.completion"
51
- created: int
52
- model: str
53
- choices: List[Dict[str, Any]]
54
- usage: Dict[str, int]
55
-
56
- # Funzioni di utilità
57
- def format_chat_prompt(messages: List[Message]) -> str:
58
- """Formatta i messaggi nel formato atteso da Mistral Instruct."""
59
- conversation = []
60
-
61
- for message in messages:
62
- if message.role == "system":
63
- # Inserisce il messaggio di sistema come istruzione iniziale
64
- conversation.append(f"<s>[INST] {message.content} [/INST]</s>")
65
- elif message.role == "user":
66
- conversation.append(f"<s>[INST] {message.content} [/INST]</s>")
67
- elif message.role == "assistant":
68
- conversation.append(f"<s>{message.content}</s>")
69
-
70
- return "".join(conversation)
71
-
72
- def load_model():
73
- """Carica il modello Mistral quantizzato."""
74
- global model, status_message
75
- try:
76
- status_message = "Caricamento modello in corso..."
77
- model = AutoModelForCausalLM.from_pretrained(
78
- MODEL_PATH,
79
- model_file=MODEL_FILE,
80
- model_type=MODEL_TYPE,
81
- context_length=4096,
82
- threads=4 # Usa 4 thread per lasciare risorse al sistema
83
- )
84
- status_message = "Modello caricato con successo"
85
- return True
86
- except Exception as e:
87
- status_message = f"Errore nel caricamento del modello: {str(e)}"
88
- return False
89
-
90
- def generate_response(prompt, temperature=0.7, top_p=0.95, max_tokens=MAX_NEW_TOKENS):
91
- """Genera una risposta dal modello."""
92
- global model, status_message
93
-
94
- if model is None:
95
- if not load_model():
96
- return status_message
97
-
98
- with MODEL_LOCK: # Previene richieste parallele che potrebbero causare OOM
99
- try:
100
- result = model(
101
- prompt,
102
- max_new_tokens=max_tokens,
103
- temperature=temperature,
104
- top_p=top_p,
105
- repetition_penalty=1.1
106
- )
107
- return result
108
- except Exception as e:
109
- return f"Errore nella generazione: {str(e)}"
110
-
111
- # API endpoint compatibile con OpenAI
112
- @app.post("/v1/chat/completions", response_model=CompletionResponse)
113
- async def create_completion(request: CompletionRequest):
114
- try:
115
- prompt = format_chat_prompt(request.messages)
116
-
117
- max_tokens = min(request.max_tokens, MAX_NEW_TOKENS) # Limita i token per evitare OOM
118
-
119
- start_time = time.time()
120
- completion_text = generate_response(
121
- prompt,
122
- temperature=request.temperature,
123
- top_p=request.top_p,
124
- max_tokens=max_tokens
125
- )
126
- end_time = time.time()
127
-
128
- # Calcola il numero di token (approssimativo)
129
- input_tokens = len(prompt.split())
130
- output_tokens = len(completion_text.split())
131
-
132
- response = {
133
- "id": f"chatcmpl-{os.urandom(4).hex()}",
134
- "object": "chat.completion",
135
- "created": int(time.time()),
136
- "model": request.model,
137
- "choices": [
138
- {
139
- "index": 0,
140
- "message": {
141
- "role": "assistant",
142
- "content": completion_text,
143
- },
144
- "finish_reason": "stop",
145
- }
146
- ],
147
- "usage": {
148
- "prompt_tokens": input_tokens,
149
- "completion_tokens": output_tokens,
150
- "total_tokens": input_tokens + output_tokens,
151
- }
152
- }
153
-
154
- return response
155
- except Exception as e:
156
- raise HTTPException(status_code=500, detail=str(e))
157
-
158
- # API endpoint per verificare lo stato del modello
159
- @app.get("/status")
160
- async def get_status():
161
- return {"status": status_message, "model": MODEL_PATH}
162
-
163
- # Interfaccia Gradio per testing manuale
164
- def create_gradio_interface():
165
- with gr.Blocks(title="Mistral API") as interface:
166
- gr.Markdown("# Mistral-7B API Server")
167
-
168
- with gr.Row():
169
- with gr.Column():
170
- status = gr.Textbox(value=lambda: status_message, label="Stato del modello", interactive=False)
171
- load_button = gr.Button("Carica Modello")
172
- load_button.click(load_model, inputs=[], outputs=[])
173
-
174
- with gr.Row():
175
- with gr.Column():
176
- input_text = gr.Textbox(
177
- lines=5,
178
- label="Input",
179
- placeholder="Inserisci il tuo messaggio qui..."
180
- )
181
-
182
- with gr.Row():
183
- temp_slider = gr.Slider(
184
- minimum=0.1,
185
- maximum=1.0,
186
- value=0.7,
187
- step=0.1,
188
- label="Temperatura"
189
- )
190
-
191
- max_token_slider = gr.Slider(
192
- minimum=100,
193
- maximum=MAX_NEW_TOKENS,
194
- value=1024,
195
- step=100,
196
- label="Max Token"
197
- )
198
-
199
- submit_button = gr.Button("Genera")
200
-
201
- with gr.Column():
202
- output_text = gr.Textbox(lines=12, label="Risposta del modello")
203
-
204
- gen_time = gr.Textbox(label="Tempo di generazione", interactive=False)
205
-
206
- def generate_with_timing(text, temp, max_tok):
207
- start_time = time.time()
208
- prompt = f"<s>[INST] {text} [/INST]</s>"
209
- result = generate_response(prompt, temperature=temp, max_tokens=max_tok)
210
- end_time = time.time()
211
- return result, f"{end_time - start_time:.2f} secondi"
212
-
213
- submit_button.click(
214
- generate_with_timing,
215
- inputs=[input_text, temp_slider, max_token_slider],
216
- outputs=[output_text, gen_time]
217
- )
218
-
219
- gr.Markdown("""
220
- ## API Endpoint
221
- Questa applicazione espone un endpoint API compatibile con OpenAI:
222
- - `/v1/chat/completions` - Per richieste di completamento chat
223
- - `/status` - Per verificare lo stato del modello
224
-
225
- L'endpoint è accessibile dall'URL di questo Hugging Face Space.
226
- """)
227
-
228
- return interface
229
-
230
- # Inizializza e avvia l'app Gradio
231
- demo = create_gradio_interface()
232
- app = gr.mount_gradio_app(app, demo, path="/")
233
-
234
- # Precarica il modello al primo avvio
235
- @app.on_event("startup")
236
- async def startup_load_model():
237
- load_model()