File size: 2,397 Bytes
376da64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
import os
import re
import json
from flask import Flask, request, jsonify

# Flask application object; the HTTP routes further down are registered on it.
app = Flask(__name__)

# The LoRA adapter directory is resolved relative to this script, so it is
# found regardless of the current working directory.
script_dir = os.path.dirname(os.path.abspath(__file__))
adapter_path = os.path.join(script_dir, "lora-playwright-adapter")
model_name = "mistralai/Mistral-7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# The tokenizer is later called with padding=True and its pad_token_id is
# passed to generate(), so a pad token must exist; fall back to the EOS token
# when the tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})

# Load the base model weights in half precision (float16).
# NOTE(review): device_map="auto" leaves device placement to the library —
# confirm this matches the deployment hardware.
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)

# Wrap the base model with the adapter stored at adapter_path, then switch the
# result to evaluation mode since this module only runs generation.
model = PeftModel.from_pretrained(base_model, adapter_path)
model.eval()

# Test goals are supplied by the caller (see the /generate endpoint); this could be extended to load them from Excel/CSV.

def generate_action_sequence(test_goals, max_new_tokens=150):
    """Generate a flat list of actions for the given test goals.

    Each goal is turned into a prompt, the model completes it, and the first
    JSON array found in the completion is parsed. The elements of every
    successfully parsed array are concatenated, in goal order, into a single
    list. Goals whose completion contains no parseable JSON array are
    reported on stdout and skipped.

    Args:
        test_goals: Iterable of goal descriptions. ``None`` is treated as
            "no goals"; a single string is treated as one goal rather than
            being iterated character by character.
        max_new_tokens: Upper bound on the number of tokens generated per
            goal. Defaults to 150.

    Returns:
        A list containing the elements of every parsed JSON array.
    """
    if test_goals is None:
        return []
    if isinstance(test_goals, str):
        test_goals = [test_goals]

    decoder = json.JSONDecoder()
    full_response = []
    for goal in test_goals:
        prompt = f"Goal: {goal}\nReturn only one valid JSON array, no explanation.\nOutput:"

        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=max_new_tokens,
                pad_token_id=tokenizer.pad_token_id,
                top_p=1.0,
                repetition_penalty=1.2,
                do_sample=False,
            )

        # For a causal LM the returned sequence starts with the prompt tokens.
        # Decode only what was generated so that brackets inside the goal
        # text can never be mistaken for the model's JSON answer.
        prompt_length = inputs["input_ids"].shape[-1]
        completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

        start = completion.find("[")
        if start == -1:
            print(f"No JSON found for goal: {goal}")
            continue

        # raw_decode parses exactly one JSON value beginning at `start` and
        # ignores whatever follows it, so trailing text after the array
        # (even text containing ']') does not break parsing.
        try:
            actions, _ = decoder.raw_decode(completion, start)
        except json.JSONDecodeError:
            print(f"Invalid JSON for goal: {goal}")
            continue

        full_response.extend(actions)

    return full_response

@app.route("/")
def health():
    """Liveness check: reply with the plain-text body ``OK`` and HTTP status 200."""
    body, status = "OK", 200
    return body, status

@app.route("/generate", methods=["POST"])
def generate():
    """Handle ``POST /generate``: turn the posted goals into an action list.

    Expects a JSON object body with an optional ``"goals"`` key holding a
    list of goal strings (a single string is accepted as one goal). A missing
    key means no goals.

    Returns:
        ``{"result": [...]}`` on success, or ``{"error": "..."}`` with HTTP
        status 400 when the body is not a JSON object or ``"goals"`` is not a
        list of strings.
    """
    # silent=True yields None for a missing or malformed JSON body instead of
    # aborting, so every bad-input case is answered with the same JSON error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    test_goals = data.get("goals", [])
    if isinstance(test_goals, str):
        test_goals = [test_goals]
    if not isinstance(test_goals, list) or not all(isinstance(goal, str) for goal in test_goals):
        return jsonify({"error": "'goals' must be a list of strings."}), 400

    result = generate_action_sequence(test_goals)
    return jsonify({"result": result})


if __name__ == "__main__":
    # Serve on every network interface. The port is read from the PORT
    # environment variable and falls back to 5000 when it is not set.
    port = int(os.getenv("PORT", 5000))
    app.run(port=port, host="0.0.0.0")