WarTitan2077 commited on
Commit
7431386
·
verified ·
1 Parent(s): 124c819

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -30
app.py CHANGED
@@ -1,42 +1,57 @@
 
1
  from transformers import pipeline
2
- from fastapi import FastAPI, Request
3
- from pydantic import BaseModel
4
- import uvicorn
5
 
6
- # Load classifier
7
- classifier = pipeline("text-classification", model="WarTitan2077/Number-Classifier", tokenizer="WarTitan2077/Number-Classifier", top_k=None)
 
 
 
 
 
8
 
9
- # Define labels (same as during training)
10
  labels = ["Symbolic", "Numeric", "Natural", "Integer", "Rational", "Irrational", "Real", "Prime", "Composite"]
11
-
12
- # Map label IDs back to names
13
  label_map = {f"LABEL_{i}": label for i, label in enumerate(labels)}
14
 
15
- # Input schema
16
- class Inputs(BaseModel):
17
- Input1: str
18
- Input2: str
19
- Input3: str
20
-
21
- # App
22
- app = FastAPI()
23
-
24
- @app.post("/predict")
25
- async def predict(data: Inputs):
26
  inputs = {
27
- "Input1": data.Input1,
28
- "Input2": data.Input2,
29
- "Input3": data.Input3
30
  }
31
 
32
- response = {}
33
 
34
- for key, value in inputs.items():
35
- if value.strip(): # only classify non-empty inputs
36
- result = classifier(value)
37
- preds = {label_map[item["label"]]: round(item["score"], 3) for item in result[0]}
38
- response[key] = preds
 
 
 
39
  else:
40
- response[key] = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- return response
 
 
1
+ import gradio as gr
2
  from transformers import pipeline
 
 
 
3
 
4
+ # Load your fine-tuned multi-label classification model
5
+ classifier = pipeline(
6
+ "text-classification",
7
+ model="your-username/your-model-name",
8
+ tokenizer="your-username/your-model-name",
9
+ top_k=None
10
+ )
11
 
12
+ # Define labels in the correct order
13
  labels = ["Symbolic", "Numeric", "Natural", "Integer", "Rational", "Irrational", "Real", "Prime", "Composite"]
 
 
14
  label_map = {f"LABEL_{i}": label for i, label in enumerate(labels)}
15
 
16
+ # Define the prediction function
17
+ def classify_numbers(input1, input2, input3):
 
 
 
 
 
 
 
 
 
18
  inputs = {
19
+ "Input1": input1,
20
+ "Input2": input2,
21
+ "Input3": input3
22
  }
23
 
24
+ results = {}
25
 
26
+ for name, value in inputs.items():
27
+ if value.strip(): # only process non-empty input
28
+ output = classifier(value)[0] # list of dicts for each label
29
+ result = {
30
+ label_map[item["label"]]: round(item["score"], 3)
31
+ for item in output
32
+ }
33
+ results[name] = result
34
  else:
35
+ results[name] = {}
36
+
37
+ return results
38
+
39
+ # Define Gradio interface
40
+ inputs = [
41
+ gr.Textbox(label="Input 1"),
42
+ gr.Textbox(label="Input 2"),
43
+ gr.Textbox(label="Input 3")
44
+ ]
45
+
46
+ output = gr.JSON(label="Predictions")
47
+
48
+ demo = gr.Interface(
49
+ fn=classify_numbers,
50
+ inputs=inputs,
51
+ outputs=output,
52
+ title="Number Classifier",
53
+ description="Enter up to three inputs to classify them as Symbolic, Numeric, Prime, etc."
54
+ )
55
 
56
+ if __name__ == "__main__":
57
+ demo.launch()