iro-malta07 commited on
Commit
08b4bcb
·
verified ·
1 Parent(s): 4d6c65d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -9
app.py CHANGED
@@ -13,24 +13,22 @@ id2label = model.config.id2label
13
  # Inference function
14
  def classify_text(text):
15
  if not text.endswith("."):
16
- text = text + "."
17
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
18
  with torch.no_grad():
19
  outputs = model(**inputs)
20
- probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
21
- pred_class = torch.argmax(probs, dim=1).item()
22
- confidence = probs[0, pred_class].item()
23
- label = id2label[pred_class]
24
- return f"Predicted Level: {label} (Confidence: {confidence:.2f})"
25
-
26
 
27
  # Gradio interface
28
  demo = gr.Interface(
29
  fn=classify_text,
30
  inputs=gr.Textbox(lines=4, placeholder="Schreibe etwas auf Deutsch..."),
31
- outputs=gr.Textbox(label="Language Level Prediction"),
32
  title="German Language Level Classifier",
33
- description="Enter German text and get the predicted CEFR level (A1 to C2).🚧 Working in progress.🚧"
34
  )
35
 
36
  # Launch app
 
13
  # Inference function
14
  def classify_text(text):
15
  if not text.endswith("."):
16
+ text += "."
17
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
18
  with torch.no_grad():
19
  outputs = model(**inputs)
20
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
21
+ # Create a dictionary of labels and their corresponding probabilities
22
+ confidences = {id2label[i]: float(probs[i]) for i in range(len(probs))}
23
+ return confidences
 
 
24
 
25
  # Gradio interface
26
  demo = gr.Interface(
27
  fn=classify_text,
28
  inputs=gr.Textbox(lines=4, placeholder="Schreibe etwas auf Deutsch..."),
29
+ outputs=gr.Label(num_top_classes=4),
30
  title="German Language Level Classifier",
31
+ description="Enter German text and get the predicted CEFR level (A1 to C2). 🚧 Work in progress. 🚧"
32
  )
33
 
34
  # Launch app