dicksinyass committed on
Commit
66043b9
·
verified ·
1 Parent(s): 998a82b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -26
app.py CHANGED
@@ -1,42 +1,125 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
 
3
  from huggingface_hub import login
4
 
5
def load_model_and_tokenizer(hf_token=None):
    """Load the tokenizer and causal-LM model for the configured checkpoint.

    Args:
        hf_token: Optional Hugging Face access token. When truthy it is passed
            to ``login`` before anything is downloaded.

    Returns:
        A ``(tokenizer, model, status_message)`` tuple. If anything raises,
        both objects are ``None`` and the message carries the error text.
    """
    # The repo id has to be a string literal; written bare it is not even
    # valid Python (`32B` does not tokenize as a name or a number).
    model_id = "ryanmarten/Sky-T1-32B-Preview-5k-1-epoch"

    try:
        if hf_token:
            login(token=hf_token)  # Use the provided token

        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
        return tokenizer, model, "Model and tokenizer loaded successfully!"
    except Exception as e:
        # Report the failure as a status string instead of raising, so the
        # caller can show it to the user.
        return None, None, f"Error loading model/tokenizer: {e}"
16
 
17
def generate_text(prompt, hf_token=None):
    """Generate up to 50 new tokens continuing ``prompt``.

    The tokenizer and model are loaded on first use and cached as attributes
    on this function. A failed load is NOT cached, so a later call (for
    example with a corrected token) retries the load.

    Args:
        prompt: Text to feed to the model.
        hf_token: Optional Hugging Face token forwarded to
            ``load_model_and_tokenizer``.

    Returns:
        The decoded generation, or an error message string if loading or
        generation fails.
    """
    tokenizer = getattr(generate_text, "tokenizer", None)
    model = getattr(generate_text, "model", None)

    if tokenizer is None or model is None:
        tokenizer, model, load_status = load_model_and_tokenizer(hf_token)
        if tokenizer is None or model is None:
            return load_status  # Return the error message if loading failed
        # Only remember the objects once we know both are usable.
        generate_text.tokenizer = tokenizer
        generate_text.model = model

    try:
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(**inputs, max_new_tokens=50)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error during generation: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
 
 
 
 
 
32
 
33
# Gradio Interface
# Three text boxes (token, prompt, output) and a button that runs
# generate_text on the prompt and token and shows the result.
with gr.Blocks() as iface:
    hf_token_input = gr.Textbox(
        label="Hugging Face Token (Optional, for gated models)",
        type="password",
    )
    prompt_input = gr.Textbox(
        label="Prompt",
        placeholder="Enter your prompt here...",
    )
    output_textbox = gr.Textbox(label="Generated Text")

    generate_button = gr.Button("Generate")
    generate_button.click(
        fn=generate_text,
        inputs=[prompt_input, hf_token_input],
        outputs=output_textbox,
    )

iface.launch()
 
1
  import gradio as gr
2
+ from transformers import pipeline, Conversation
3
+ import torch
4
  from huggingface_hub import login
5
 
6
# --- Configuration ---

# Hub ids of the chat models the app can load.
MODEL_CHOICES = [
    "mistralai/Mistral-7B-Instruct-v0.2",  # Good balance
    "meta-llama/Llama-2-70b-chat-hf",  # Higher quality, requires HF token, more resources
    "mistralai/Mixtral-8x7B-Instruct-v0.1",  # Potentially best quality, high resources
    "codellama/CodeLlama-70b-Instruct-hf",  # Best for code, high resources
]

# Run on the first CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"
16
+
17
+
18
+ # --- Helper Functions ---
19
+
20
def load_model(model_name, hf_token=None):
    """Build a conversational pipeline for ``model_name``.

    Args:
        model_name: Hub id of the model to load.
        hf_token: Optional Hugging Face access token. When truthy it is passed
            to ``login`` before the model is fetched.

    Returns:
        A ``(pipe, status_message)`` tuple. If anything raises, ``pipe`` is
        ``None`` and the message carries the error text.
    """
    import importlib.util  # local import: only needed for the flash-attn probe

    try:
        if hf_token:
            login(token=hf_token)

        on_gpu = str(DEVICE).startswith("cuda")

        # Pick a dtype the hardware can actually run: bfloat16 only where the
        # GPU supports it, float16 on other GPUs, full precision on CPU.
        if on_gpu and torch.cuda.is_bf16_supported():
            dtype = torch.bfloat16
        elif on_gpu:
            dtype = torch.float16
        else:
            dtype = torch.float32

        # Model-loading options belong in `model_kwargs`; a stray top-level
        # keyword is treated as a pipeline/generation argument instead.
        # Flash attention is requested only when it can work at all (CUDA
        # device and the `flash_attn` package installed).
        model_kwargs = {}
        if on_gpu and importlib.util.find_spec("flash_attn") is not None:
            model_kwargs["attn_implementation"] = "flash_attention_2"

        # Use a pipeline for easier interaction
        pipe = pipeline(
            "conversational",
            model=model_name,
            device=DEVICE,  # Move to GPU if available
            torch_dtype=dtype,
            trust_remote_code=True,  # Important for custom models
            model_kwargs=model_kwargs,
        )
        return pipe, "Model loaded successfully!"
    except Exception as e:
        # Report the failure as a status string so the caller can display it.
        return None, f"Error loading model: {e}"
38
 
39
+
40
def generate_response(prompt, chat_history, model_name, hf_token=None):
    """Generate a chatbot reply and append the new turn to the history.

    Pipelines are cached per model name on this function object, so switching
    back to a model that was already loaded does not reload it.

    Args:
        prompt: The user's new message.
        chat_history: Earlier turns as ``(user_message, bot_message)`` pairs.
            ``None`` is treated as an empty history.
        model_name: Hub id of the model to chat with.
        hf_token: Optional Hugging Face token forwarded to ``load_model``.

    Returns:
        A ``(textbox_text, chat_history)`` tuple. ``textbox_text`` is ``""``
        on success, or an error message if loading or generation fails.
    """
    # NOTE(review): a cleared Chatbot may hand back None instead of [] -
    # confirm against the Gradio version in use. Normalise either way, and
    # copy so the caller's list is not mutated in place.
    chat_history = list(chat_history) if chat_history else []

    # Nothing to answer: skip model loading and generation entirely.
    if not prompt or not prompt.strip():
        return "", chat_history

    # Use a dictionary to store loaded models for faster switching
    if not hasattr(generate_response, "loaded_models"):
        generate_response.loaded_models = {}
    cache = generate_response.loaded_models

    pipe = cache.get(model_name)
    if pipe is None:
        pipe, load_status = load_model(model_name, hf_token)
        if pipe is None:
            return load_status, chat_history
        cache[model_name] = pipe
        print(f"Model {model_name} loaded.")  # Debugging message
    else:
        print(f"Using cached model {model_name}.")  # Debugging message

    try:
        # Convert Gradio chat history to transformers Conversation format
        conversation = Conversation()
        for user_message, bot_message in chat_history:
            conversation.add_message({"role": "user", "content": user_message})
            if bot_message:  # Handle case where bot hasn't responded yet
                conversation.add_message({"role": "assistant", "content": bot_message})
        conversation.add_message({"role": "user", "content": prompt})

        # Generation parameters (adjust these!)
        generation_kwargs = {
            "max_new_tokens": 512,
            "do_sample": True,
            "top_p": 0.95,
            "temperature": 0.7,
            "repetition_penalty": 1.1,
        }

        # Generate the response
        response = pipe(conversation, **generation_kwargs)

        # The last message of the returned Conversation is read as the reply.
        bot_response = response.messages[-1]["content"]

        # Update the chat history
        chat_history.append((prompt, bot_response))
        return "", chat_history
    except Exception as e:
        # UI boundary: show the failure to the user and keep the history.
        return f"Error during generation: {e}", chat_history
89
+
90
+
91
# --- Gradio Interface ---

# Layout, top to bottom: intro text, model dropdown, optional token box,
# chat window, message box and a clear button. Submitting the message box
# calls generate_response and writes its two return values back to the
# message box and the chat window.
with gr.Blocks(title="Chat with a Powerful AI") as iface:
    gr.Markdown(
        """
        # Chat with Different AI Models
        This Space demonstrates a chatbot that allows you to select from different AI models.
        Choose a model from the dropdown and start chatting!
        """
    )

    model_selection = gr.Dropdown(
        choices=MODEL_CHOICES,
        value=MODEL_CHOICES[0],  # Default model
        label="Select Model",
        info="Choose the AI model you want to chat with.",
    )

    hf_token_input = gr.Textbox(
        label="Hugging Face Token (Optional, for gated models)",
        type="password",
        placeholder="Enter your Hugging Face token (if required)",
    )

    chatbot = gr.Chatbot(label="Chat History", height=500)  # Set a reasonable height
    msg = gr.Textbox(label="Your Message", placeholder="Type your message here...")
    clear = gr.ClearButton([msg, chatbot])

    msg.submit(
        fn=generate_response,
        inputs=[msg, chatbot, model_selection, hf_token_input],
        outputs=[msg, chatbot],
    )

iface.launch()