syedmudassir16 commited on
Commit
daf54ea
·
verified ·
1 Parent(s): d2a398f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -19
app.py CHANGED
@@ -1,14 +1,14 @@
1
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Load the instruct-tuned Mistral checkpoint locally in half precision;
# device_map="auto" lets accelerate spread layers over available devices.
model_name = "mistralai/Mistral-7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)

# One reusable text-generation pipeline, capped at 256 new tokens per call.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=256)
 
 
12
 
13
  def format_prompt(message, history):
14
  fixed_prompt = """
@@ -83,18 +83,34 @@ def classify_mood(input_string):
83
  return word, True
84
  return None, False
85
 
86
def generate(prompt, history):
    """Produce a chat reply for *prompt* given the conversation *history*.

    Runs the local text-generation pipeline on the formatted prompt,
    strips the echoed prompt from the model output, and short-circuits
    to a playlist message when a mood keyword is detected.
    """
    formatted_prompt = format_prompt(prompt, history)

    # The pipeline echoes the prompt in `generated_text`; slice it off below.
    generated = pipe(
        formatted_prompt,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
    )[0]["generated_text"]
    reply = generated[len(formatted_prompt):].strip()

    # A recognised mood replaces the raw model text with a playlist cue.
    mood, is_classified = classify_mood(reply)
    if is_classified:
        return f"Playing {mood.capitalize()} playlist for you!"
    return reply
 
98
 
99
  def chat(message, history):
100
  response = generate(message, history)
 
1
import os

import gradio as gr
from huggingface_hub import InferenceClient

# Read the access token first so a missing token fails fast at startup
# with a clear message instead of a 401 at request time.
# Set HUGGINGFACE_TOKEN in your Hugging Face Space secrets.
HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
if HF_TOKEN is None:
    raise ValueError("Please set the HUGGINGFACE_TOKEN environment variable in your Hugging Face Space.")

# BUG FIX: HF_TOKEN was previously validated but never passed to the client,
# so all inference calls went out anonymously. Pass it explicitly so gated
# models (like Mistral) authenticate correctly.
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1", token=HF_TOKEN)
12
 
13
  def format_prompt(message, history):
14
  fixed_prompt = """
 
83
  return word, True
84
  return None, False
85
 
86
def generate(
    prompt, history, temperature=0.7, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    """Stream a reply from the hosted model and post-process it for mood.

    Accumulates streamed tokens; after each token, checks whether the text
    so far contains a recognised mood keyword and, if so, returns a
    playlist message immediately instead of the raw model output.

    Parameters:
        prompt: current user message.
        history: prior conversation turns, consumed by ``format_prompt``.
        temperature, max_new_tokens, top_p, repetition_penalty: sampling knobs
            forwarded to the text-generation endpoint.

    Returns:
        Either a "Playing <Mood> playlist for you!" message or the full
        generated text.
    """
    # Clamp temperature away from zero — the endpoint rejects values at/near 0
    # when do_sample=True.
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
    )

    formatted_prompt = format_prompt(prompt, history)

    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    output = ""

    for response in stream:
        # BUG FIX: skip special tokens (e.g. </s>) so control markers do not
        # leak into the user-visible reply.
        if getattr(response.token, "special", False):
            continue
        output += response.token.text
        # Check after every token so we can cut generation short as soon as
        # a mood keyword appears.
        mood, is_classified = classify_mood(output)
        if is_classified:
            playlist_message = f"Playing {mood.capitalize()} playlist for you!"
            return playlist_message
    return output
114
 
115
  def chat(message, history):
116
  response = generate(message, history)