aledraa committed on
Commit
1ac97be
·
verified ·
1 Parent(s): 383ed8c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +148 -0
app.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ import torch
5
+ import json
6
+ import random
7
+ from typing import List, Optional
8
+
9
# ASGI application object; the route decorators further down register on it.
app = FastAPI(title="Qwen Data Generator API")

# The model and tokenizer are loaded once, at import time, and shared by every
# request handler in this module. "auto" leaves dtype and device placement to
# the transformers loader.
model_name = "Qwen/Qwen2.5-3B-Instruct"
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("Model loaded successfully!")
21
+
22
class GenerationRequest(BaseModel):
    """Request body accepted by the ``/generate`` endpoint."""

    # One free-text instruction per output column, in column order.
    llm_commands: List[str]
    # Number of rows the model is asked to produce.
    batch_size: int = 50
    # Optional RNG seed for reproducible sampling; omitted means unseeded.
    seed: Optional[int] = None
26
+
27
class GenerationResponse(BaseModel):
    """Response body returned by the ``/generate`` endpoint."""

    # True when rows were generated and parsed; False on any failure.
    success: bool
    # Generated rows, each a list of string cells (empty on failure).
    data: List[List[str]]
    # Human-readable failure description; None on success.
    error: Optional[str] = None
31
+
32
def generate_data_prompt(llm_commands: List[str], batch_size: int) -> str:
    """Build the user prompt asking the model for ``batch_size`` rows of data.

    Args:
        llm_commands: One free-text description per output column, in column
            order.
        batch_size: Number of rows the model is asked to produce.

    Returns:
        The prompt text. It ends with ``JSON Array:`` so the model's reply is
        expected to start with the array itself.
    """
    columns_description = "\n".join(
        f"Column {position}: {cmd}"
        for position, cmd in enumerate(llm_commands, start=1)
    )

    # The format example carries exactly one placeholder per requested column.
    # The /generate handler discards rows whose length differs from
    # len(llm_commands), so a fixed two-value example would steer the model
    # toward rows that get thrown away whenever the column count is not two.
    example_row = (
        "["
        + ",".join(f'"value{position}"' for position in range(1, len(llm_commands) + 1))
        + "]"
    )
    example = f"[{example_row},{example_row},...]"

    prompt_lines = [
        f"Generate {batch_size} unique random rows of data based on these specifications:",
        columns_description,
        "",
        "Requirements:",
        "- Each row must be different and realistic",
        f"- Return ONLY a JSON array format: {example}",
        "- No additional text, explanations, or formatting",
        "- Values should be diverse and not repetitive",
        "",
        "JSON Array:",
    ]
    return "\n".join(prompt_lines)
47
+
48
def _extract_rows(response_text: str, expected_columns: int) -> List[List[str]]:
    """Extract the rows of the JSON array embedded in the model's reply.

    The slice from the first ``[`` to the last ``]`` is parsed as JSON. Only
    elements that are lists with exactly ``expected_columns`` items are kept,
    and every cell is converted to ``str``.

    Raises:
        ValueError: If no bracketed span exists or the parsed value is not a
            list.
        json.JSONDecodeError: If the bracketed span is not valid JSON.
    """
    start_idx = response_text.find("[")
    end_idx = response_text.rfind("]") + 1

    # `end_idx <= start_idx` also covers a missing "]" (end_idx == 0) and a
    # "]" that only occurs before the first "[", which would otherwise yield
    # an empty slice and a misleading JSON parse error.
    if start_idx == -1 or end_idx <= start_idx:
        raise ValueError("No JSON array found in response")

    parsed_data = json.loads(response_text[start_idx:end_idx])

    # Defensive: the slice starts with "[", so a successful parse is a list.
    if not isinstance(parsed_data, list):
        raise ValueError("Response is not a list")

    return [
        [str(cell) for cell in row]
        for row in parsed_data
        if isinstance(row, list) and len(row) == expected_columns
    ]


@app.post("/generate", response_model=GenerationResponse)
async def generate_data(request: GenerationRequest):
    """Generate rows of synthetic data with the language model.

    Args:
        request: Column instructions, requested row count and optional seed.

    Returns:
        A ``GenerationResponse``. Failures are never raised to the client;
        they are reported as ``success=False`` with an ``error`` message
        prefixed by the stage that failed (generation, JSON parsing, or data
        processing). ``data`` may hold fewer rows than requested because rows
        with the wrong number of columns are dropped.
    """
    try:
        if not request.llm_commands:
            raise ValueError("llm_commands must not be empty")
        if request.batch_size < 1:
            raise ValueError("batch_size must be at least 1")

        # Compare against None rather than truthiness: 0 is a valid seed.
        if request.seed is not None:
            torch.manual_seed(request.seed)
            random.seed(request.seed)

        prompt = generate_data_prompt(request.llm_commands, request.batch_size)

        messages = [
            {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant that generates structured data."},
            {"role": "user", "content": prompt},
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

        # NOTE(review): model.generate is a synchronous call made directly in
        # an async handler, so other requests (including /health) presumably
        # wait while it runs. Confirm whether that serialization is intended
        # before moving it to a worker thread.
        with torch.no_grad():
            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=2048,
                temperature=0.8,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        # generate() returns prompt + completion; keep only the completion.
        completion_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        response_text = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)[0]
    except Exception as e:
        return GenerationResponse(
            success=False,
            data=[],
            error=f"Generation error: {str(e)}",
        )

    try:
        valid_rows = _extract_rows(response_text, len(request.llm_commands))
    except json.JSONDecodeError as e:
        # Must precede the generic handler: JSONDecodeError is a ValueError.
        return GenerationResponse(
            success=False,
            data=[],
            error=f"Failed to parse JSON: {str(e)}",
        )
    except Exception as e:
        return GenerationResponse(
            success=False,
            data=[],
            error=f"Data processing error: {str(e)}",
        )

    return GenerationResponse(success=True, data=valid_rows)
141
+
142
@app.get("/health")
async def health_check():
    """Return a static health payload naming the configured model."""
    payload = {"status": "healthy", "model": model_name}
    return payload
145
+
146
if __name__ == "__main__":
    # Imported here so uvicorn is only required when the file is run as a
    # script rather than imported by an external ASGI server.
    import uvicorn

    # Listen on every interface, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")