Shrees0507 commited on
Commit
7300c08
·
verified ·
1 Parent(s): 8520ff5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -20
app.py CHANGED
@@ -1,33 +1,64 @@
1
- from fastapi import FastAPI, Request
2
- from transformers import pipeline
3
- import json
4
  import os
 
5
 
6
  app = FastAPI()
7
 
8
- # Set your Hugging Face API token
9
- huggingface_token = os.environ.get("huggingface_token")
10
 
11
- # Load the model
12
- generator = pipeline("text-generation", model="EleutherAI/gpt-neo-2.7B")
 
13
 
14
- conversation_history = []
 
 
15
 
16
  @app.post("/chat")
17
- async def chat(message: dict):
18
- global conversation_history
19
- # Extract user message
20
- user_message = message["message"]
21
- conversation_history.append({"role": "user", "content": user_message})
 
 
 
 
 
 
 
22
 
23
- messages = [msg["content"] for msg in conversation_history]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- # Add a system message to instruct the model
26
- messages.insert(0, "You are a story writer. Whatever the prompt, you always write a short story of 30 words.")
 
 
27
 
28
- # Generate response using Hugging Face model
29
- reply = generator(messages, max_length=30, do_sample=False)[0]["generated_text"]
 
30
 
31
- conversation_history.append({"role": "assistant", "content": reply})
 
32
 
33
- return reply
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from huggingface_hub import InferenceClient
4
  import os
5
+ import gradio as gr
6
 
7
  app = FastAPI()
8
 
9
+ client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta", token=os.getenv("huggingface_token"))
 
10
 
11
+ class Message(BaseModel):
12
+ input: str
13
+ history: list
14
 
15
+ @app.get("/")
16
+ async def home():
17
+ return {"message": "Welcome to the chatbot API!"}
18
 
19
  @app.post("/chat")
20
+ async def chat(message: Message):
21
+ try:
22
+ input_message = message.input
23
+ history = message.history
24
+
25
+ # Call the chatbot function
26
+ response = chatbot(input_message, history)
27
+
28
+ return {"response": response}
29
+
30
+ except Exception as e:
31
+ raise HTTPException(status_code=500, detail=str(e))
32
 
33
+ def chatbot(input, history):
34
+ try:
35
+ # Call your chatbot function here
36
+ message = [{"role": "user", "content": input}]
37
+ history = [{"role": "system", "content": "You are a helpful assistant."}]
38
+ messages = history + message
39
+
40
+ output = client.chat_completion(
41
+ messages=messages,
42
+ max_tokens=256,
43
+ temperature=0.7
44
+ )
45
+
46
+ history = history + [{"role": "assistant", "content": output.choices[0].message.content}]
47
+
48
+ return output.choices[0].message.content
49
+
50
+ except Exception as e:
51
+ raise Exception(str(e))
52
 
53
+ # Define the Gradio chat interface
54
+ def gradio_chat(input, history):
55
+ response = chatbot(input, history)
56
+ return response
57
 
58
+ # Define Gradio inputs and outputs
59
+ inputs = [gr.Textbox(lines=5, label="Input"), gr.Textbox(lines=5, label="History")]
60
+ output = gr.Textbox(label="Response")
61
 
62
+ # Create Gradio interface
63
+ gr.Interface(fn=gradio_chat, inputs=inputs, outputs=output, title="Chatbot").launch()
64