NotSoundRated commited on
Commit
916976b
·
verified ·
1 Parent(s): 81ff115

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -4
app.py CHANGED
@@ -1,6 +1,63 @@
1
  import gradio as gr
 
 
2
 
3
- gr.load(
4
- "models/microsoft/deberta-v3-small",
5
- provider="hf-inference",
6
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ import torch
4
 
5
+ # Define your actions
6
+ ACTIONS = [
7
+ "move_towards_player",
8
+ "stand_still",
9
+ "follow_player",
10
+ "run_away",
11
+ "jump",
12
+ "mine_block",
13
+ "place_block",
14
+ "attack",
15
+ "use_item",
16
+ "chat_only"
17
+ ]
18
+
19
+ # Load model and tokenizer (you'll fine-tune this on your specific task)
20
+ model_name = "models/microsoft/deberta-v3-small" # Starting point, you'll need to fine-tune
21
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
22
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(ACTIONS))
23
+
24
+ # This function will need to be trained/fine-tuned with your data
25
+ def predict_action(text):
26
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
27
+ outputs = model(**inputs)
28
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
29
+ predicted_class = torch.argmax(probabilities, dim=-1).item()
30
+
31
+ # Return both the action and the confidence score
32
+ return {
33
+ "action": ACTIONS[predicted_class],
34
+ "confidence": probabilities[0][predicted_class].item(),
35
+ "all_actions": {ACTIONS[i]: probabilities[0][i].item() for i in range(len(ACTIONS))}
36
+ }
37
+
38
+ # Create the Gradio interface
39
+ def process_text(text):
40
+ result = predict_action(text)
41
+ return result["action"], result["confidence"], result["all_actions"]
42
+
43
+ # Define the Gradio interface
44
+ with gr.Blocks() as demo:
45
+ gr.Markdown("# Minecraft Action Predictor")
46
+ with gr.Row():
47
+ text_input = gr.Textbox(label="Character text")
48
+ action_output = gr.Textbox(label="Predicted action")
49
+ with gr.Row():
50
+ confidence_output = gr.Number(label="Confidence")
51
+ all_actions_output = gr.JSON(label="All action probabilities")
52
+
53
+ text_input.change(process_text, inputs=text_input, outputs=[action_output, confidence_output, all_actions_output])
54
+
55
+ gr.Interface(
56
+ fn=predict_action,
57
+ inputs=gr.Textbox(label="Character text"),
58
+ outputs=gr.JSON(label="Result"),
59
+ title="Minecraft Action Predictor API",
60
+ description="Predicts the best action based on character text"
61
+ )
62
+
63
+ demo.launch()