iro-malta07 commited on
Commit
487660b
·
verified ·
1 Parent(s): edccf87

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -13
app.py CHANGED
@@ -3,24 +3,26 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
 
5
  # Load model and tokenizer from Hugging Face
6
- # model_name = "your-username/your-model-name" # replace with your model path
7
- # tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- # model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
 
10
  # Use id2label from the config
11
- # id2label = model.config.id2label
12
 
13
  # Inference function
14
  def classify_text(text):
15
- # inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
16
- #with torch.no_grad():
17
- # outputs = model(**inputs)
18
- # probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
19
- # pred_class = torch.argmax(probs, dim=1).item()
20
- #confidence = probs[0, pred_class].item()
21
- # label = id2label[pred_class]
22
- # return f"Predicted Level: {label} (Confidence: {confidence:.2f})"
23
- return f"Predicted Level: a2 (Confidence: 0.5)"
 
 
24
 
25
  # Gradio interface
26
  demo = gr.Interface(
 
3
  import torch
4
 
5
  # Load model and tokenizer from Hugging Face
6
+ model_name = "iro-malta07/distilbert-base-german-lang-level-class" # replace with your model path
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
 
10
  # Use id2label from the config
11
+ id2label = model.config.id2label
12
 
13
  # Inference function
14
  def classify_text(text):
15
+ if not text.endswith("."):
16
+ text = text + "."
17
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
18
+ with torch.no_grad():
19
+ outputs = model(**inputs)
20
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
21
+ pred_class = torch.argmax(probs, dim=1).item()
22
+ confidence = probs[0, pred_class].item()
23
+ label = id2label[pred_class]
24
+ return f"Predicted Level: {label} (Confidence: {confidence:.2f})"
25
+
26
 
27
  # Gradio interface
28
  demo = gr.Interface(