abhiimanyu committed
Commit b7f8795 · verified · 1 Parent(s): 0a7f020

Update app.py

Files changed (1)
app.py +82 -72
app.py CHANGED
@@ -1,85 +1,95 @@
-import os
-import spaces
-import torch
 import gradio as gr
-from transformers import AutoTokenizer
-
-HF_TOKEN = os.environ.get("HF_TOKEN", None)
-
-TITLE = "<h1><center>Learning Content Generator</center></h1>"
-
-PLACEHOLDER = """
-<center>
-<p>Enter the topic, description, and difficulty level to generate learning content.</p>
-</center>
-"""
-
-from huggingface_hub import snapshot_download
-from pathlib import Path
-
-mistral_models_path = Path.home().joinpath('mistral_models', '7B-Mathstral')
-mistral_models_path.mkdir(parents=True, exist_ok=True)
-
-# Update the repo_id to the new model
-snapshot_download(repo_id="mistralai/Mathstral-7B-v0.1", allow_patterns=["params.json", "consolidated.safetensors"], local_dir=mistral_models_path)
-
-from mistral_inference.transformer import Transformer
-from mistral_inference.generate import generate
-
-from mistral_common.protocol.instruct.messages import AssistantMessage, UserMessage
-from mistral_common.protocol.instruct.request import ChatCompletionRequest
-
-# Force device to "cpu"
-device = "cpu"
-
-# Use a pretrained tokenizer from Hugging Face's transformers
-tokenizer = AutoTokenizer.from_pretrained("mistralai/Mathstral-7B-v0.1")
-
-# Load the model and move it to the CPU
-model = Transformer.from_folder(mistral_models_path).to(device=device)
-
-@spaces.GPU()
-def generate_learning_content(topic: str, description: str, difficulty: str, temperature: float = 0.3, max_new_tokens: int = 1024):
-    message = f"Generate learning content on the topic '{topic}' with the description: '{description}' and difficulty level: '{difficulty}'. Provide the content in paragraph format."
-    conversation = [UserMessage(content=message)]
-
-    completion_request = ChatCompletionRequest(messages=conversation)
-
-    # Encode tokens and move to CPU
-    tokens = tokenizer.encode(completion_request.messages[0].content, return_tensors="pt").to(device)
-
-    # Generate output using the model on CPU
-    out_tokens, _ = generate(
-        [tokens],
-        model,
-        max_tokens=max_new_tokens,
         temperature=temperature,
-        eos_id=tokenizer.eos_token_id
     )
-
-    # Decode the output tokens to human-readable text
-    result = tokenizer.decode(out_tokens[0], skip_special_tokens=True)
-
-    return result
-
-with gr.Blocks(theme="ocean") as demo:
-    gr.HTML(TITLE)
-    topic_input = gr.Textbox(label="Topic", placeholder="Enter the topic for learning content.")
-    description_input = gr.Textbox(label="Description", placeholder="Enter a brief description of the topic.")
-    difficulty_input = gr.Textbox(label="Difficulty Level", placeholder="Enter the difficulty level (easy, medium, hard).")
-
-    temperature_slider = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.3, label="Temperature")
-    tokens_slider = gr.Slider(minimum=128, maximum=8192, step=1, value=1024, label="Max New Tokens")
-
-    output = gr.Textbox(label="Generated Learning Content")
-
-    submit_button = gr.Button("Generate Content")
-
-    submit_button.click(
-        fn=generate_learning_content,
-        inputs=[topic_input, description_input, difficulty_input, temperature_slider, tokens_slider],
-        outputs=output,
-    )
-
-if __name__ == "__main__":
-    demo.launch()
+from huggingface_hub import InferenceClient
 import gradio as gr

+client = InferenceClient(
+    "mistralai/Mistral-7B-Instruct-v0.3"
+)
+
+
+def format_prompt(message, history):
+    prompt = "<s>"
+    for user_prompt, bot_response in history:
+        prompt += f"[INST] {user_prompt} [/INST]"
+        prompt += f" {bot_response}</s> "
+    prompt += f"[INST] {message} [/INST]"
+    return prompt
+
+def generate(
+    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
+):
+    temperature = float(temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(top_p)
+
+    generate_kwargs = dict(
         temperature=temperature,
+        max_new_tokens=max_new_tokens,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        do_sample=True,
+        seed=42,
     )
+
+    formatted_prompt = format_prompt(prompt, history)
+
+    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+    output = ""
+
+    for response in stream:
+        output += response.token.text
+        yield output
+    return output
+
+
+additional_inputs = [
+    gr.Slider(
+        label="Temperature",
+        value=0.9,
+        minimum=0.0,
+        maximum=1.0,
+        step=0.05,
+        interactive=True,
+        info="Higher values produce more diverse outputs",
+    ),
+    gr.Slider(
+        label="Max new tokens",
+        value=256,
+        minimum=0,
+        maximum=1048,
+        step=64,
+        interactive=True,
+        info="The maximum number of new tokens",
+    ),
+    gr.Slider(
+        label="Top-p (nucleus sampling)",
+        value=0.90,
+        minimum=0.0,
+        maximum=1.0,
+        step=0.05,
+        interactive=True,
+        info="Higher values sample more low-probability tokens",
+    ),
+    gr.Slider(
+        label="Repetition penalty",
+        value=1.2,
+        minimum=1.0,
+        maximum=2.0,
+        step=0.05,
+        interactive=True,
+        info="Penalize repeated tokens",
+    ),
+]
+
+
+gr.ChatInterface(
+    fn=generate,
+    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
+    additional_inputs=additional_inputs,
+    title="""Mistral 7B v0.3"""
+).launch(show_api=False)
+
+
+gr.load("models/ehristoforu/dalle-3-xl-v2").launch()
+
+gr.load("models/microsoft/Phi-3-mini-4k-instruct").launch()