MiCkSoftware committed on
Commit
0a33686
·
1 Parent(s): dcf8fc6

change method

Browse files
Files changed (2) hide show
  1. app.py +50 -14
  2. requirements.txt +3 -4
app.py CHANGED
@@ -1,17 +1,53 @@
1
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer


# Load the Zephyr-7B chat model and its tokenizer from the Hugging Face Hub.
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)


def generate_text(prompt):
    """Return *prompt* continued by the model (up to 100 tokens total)."""
    encoded = tokenizer(prompt, return_tensors="pt")
    generated = model.generate(
        encoded["input_ids"], max_length=100, num_return_sequences=1
    )
    return tokenizer.decode(generated[0], skip_special_tokens=True)


# Minimal text-in / text-out Gradio UI around the generator.
iface = gr.Interface(fn=generate_text, inputs="text", outputs="text")
iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from huggingface_hub import InferenceClient
from typing import List, Tuple

# Client for the Hugging Face Inference API, pointed at the Zephyr-7B chat model.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

# FastAPI application instance exposing the prediction endpoint.
app = FastAPI()
 
 
 
11
 
12
# Request schema for the /predict endpoint.
class PredictionRequest(BaseModel):
    """Input payload for a chat prediction.

    Carries the new user message, the prior conversation turns, and the
    sampling parameters forwarded to the inference call.
    """

    # Latest user message to answer.
    message: str
    # Prior conversation as (user, assistant) pairs; defaults to no history.
    history: List[Tuple[str, str]] = []
    # System prompt prepended to the conversation.
    system_message: str = "You are a friendly Chatbot."
    # Sampling parameters passed through to chat_completion.
    max_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95
21
@app.post("/predict")
async def predict(request: PredictionRequest):
    """REST endpoint that runs a chat completion against the Zephyr model.

    Builds an OpenAI-style message list from the system prompt, the prior
    (user, assistant) turns, and the new user message, then streams the
    completion from the Hugging Face Inference API and returns the full text.

    Args:
        request: Validated payload (message, history, sampling parameters).

    Returns:
        dict: ``{"response": <generated text>}``.

    Raises:
        HTTPException: 500 with the underlying error message if inference fails.
    """
    # Assemble the chat transcript in the format chat_completion expects.
    messages = [{"role": "system", "content": request.system_message}]
    for user_input, assistant_response in request.history:
        # Skip empty halves so partially-filled history entries don't inject
        # empty turns into the prompt.
        if user_input:
            messages.append({"role": "user", "content": user_input})
        if assistant_response:
            messages.append({"role": "assistant", "content": assistant_response})
    messages.append({"role": "user", "content": request.message})

    # Stream the completion and accumulate the generated tokens.
    try:
        response = ""
        for chunk in client.chat_completion(
            messages,
            max_tokens=request.max_tokens,
            stream=True,
            temperature=request.temperature,
            top_p=request.top_p,
        ):
            # Fix: delta.content can be None on some streamed chunks (e.g. the
            # final one); guard so `+=` cannot raise TypeError mid-stream.
            token = chunk.choices[0].delta.content
            if token:
                response += token
        return {"response": response}
    except Exception as e:
        # Surface any inference failure as a 500 with the error detail.
        raise HTTPException(status_code=500, detail=str(e))
49
+
50
# Local development entry point: serve the API on all interfaces, port 7860.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
requirements.txt CHANGED
@@ -1,4 +1,3 @@
1
- huggingface_hub==0.25.2
2
- transformers
3
- gradio
4
- torch==1.13.1
 
1
+ fastapi
2
+ uvicorn
3
+ huggingface_hub