rhea2809 commited on
Commit
dc7f5ae
·
1 Parent(s): 31a8145

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ from huggingface_hub import HfApi, ModelFilter
4
+
5
+ # Get the list of models from the Hugging Face Hub
6
+ api = HfApi()
7
+ models = api.list_models(filter=ModelFilter(tags="text-generation"))
8
+ models_names = [model.modelId for model in models]
9
+
10
+ # Dictionary to store loaded models and their pipelines
11
+ model_pipelines = {}
12
+
13
+ # Load a default model initially
14
+ default_model_name = "gia-project/gia2-small-untrained"
15
+ default_generator = pipeline("text-generation", model=default_model_name, trust_remote_code=True)
16
+ model_pipelines[default_model_name] = default_generator
17
+
18
+ def generate_text(model_name, input_text):
19
+ # Check if the selected model is already loaded
20
+ if model_name not in model_pipelines:
21
+ # Load the model and create a pipeline if it's not already loaded
22
+ generator = pipeline("text-generation", model=model_name, trust_remote_code=True)
23
+ model_pipelines[model_name] = generator
24
+
25
+ # Get the pipeline for the selected model and generate text
26
+ generator = model_pipelines[model_name]
27
+ generated_text = generator(input_text)[0]['generated_text']
28
+ return generated_text
29
+
30
+ # Define the Gradio interface
31
+ iface = gr.Interface(
32
+ fn=generate_text, # Function to be called on user input
33
+ inputs=[
34
+ gr.inputs.Dropdown(choices=models_names, label="Select Model", default=default_model_name), # Dropdown to select model
35
+ gr.inputs.Textbox(lines=5, label="Input Text") # Textbox for entering text
36
+ ],
37
+ outputs=gr.outputs.Textbox(label="Generated Text"), # Textbox to display the generated text
38
+ title="GIA Text Generation", # Title of the interface
39
+ )
40
+
41
+ # Launch the Gradio interface
42
+ iface.launch()